Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
@@ -7,109 +9,294 @@ purposes.
import argparse
import json
import ssl
from collections.abc import Sequence
from dataclasses import field
from typing import Literal
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
validate_chat_template,
)
from vllm.entrypoints.constants import (
H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.tool_parsers import ToolParserManager
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str | Sequence[str] | None,
option_string: str | None = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
lora_list: list[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
if item in [None, ""]: # Skip if item is None or empty string
continue
if "=" in item and "," not in item: # Old format: name=path
name, path = item.split("=")
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
def make_arg_parser():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
@config
@dataclass
class FrontendArgs:
"""Arguments for the OpenAI-compatible frontend server."""
host: str | None = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: str | None = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal[
"debug", "info", "warning", "error", "critical", "trace"
] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: list[str] | None = None
"""If provided, the server will require one of these keys to be presented in
the header."""
lora_modules: list[LoRAModulePath] | None = None
"""LoRA modules configurations in either 'name=path' format or JSON format
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`"""
chat_template: str | None = None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: str | None = None
"""The file path to the SSL key file."""
ssl_certfile: str | None = None
"""The file path to the SSL cert file."""
ssl_ca_certs: str | None = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
root_path: str | None = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
return_tokens_as_token_ids: bool = False
"""When `--max-logprobs` is specified, represents single tokens as
strings of the form 'token_id:{token_id}' so that tokens that are not
JSON-encodable can be identified."""
disable_frontend_multiprocessing: bool = False
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
enable_auto_tool_choice: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: str | None = None
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
Required for `--enable-auto-tool-choice`. You can choose any option from
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
tool_parser_plugin: str = ""
"""Special the tool parser plugin write to parse the model-generated tool
into OpenAI API format, the name register in this plugin can be used in
`--tool-call-parser`."""
tool_server: str | None = None
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
purpose."""
log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH
"""Path to logging config JSON file for both vllm and uvicorn"""
max_log_len: int | None = None
"""Max number of prompt characters or prompt ID numbers being printed in
log. The default of None means unlimited."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
enable_prompt_tokens_details: bool = False
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the `/tokenizer_info` endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If True, log model outputs (generations).
Requires --enable-log-requests."""
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
tokens_only: bool = False
"""
If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup.
"""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
from vllm.engine.arg_utils import get_kwargs
frontend_kwargs = get_kwargs(FrontendArgs)
# Special case: allowed_origins, allowed_methods, allowed_headers all
# need json.loads type
# Should also remove nargs
frontend_kwargs["allowed_origins"]["type"] = json.loads
frontend_kwargs["allowed_methods"]["type"] = json.loads
frontend_kwargs["allowed_headers"]["type"] = json.loads
del frontend_kwargs["allowed_origins"]["nargs"]
del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: LoRA modules need custom parser action and
# optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str
if "nargs" in frontend_kwargs["middleware"]:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
frontend_kwargs["tool_call_parser"]["metavar"] = (
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
)
frontend_group = parser.add_argument_group(
title="Frontend",
description=FrontendArgs.__doc__,
)
for key, value in frontend_kwargs.items():
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
return parser
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Create the CLI argument parser used by the OpenAI API server.
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
register all arguments instead of manually enumerating them here. This
avoids code duplication and keeps the argument definitions in one place.
"""
parser.add_argument(
"--uvicorn-log-level",
"model_tag",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=nullable_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=nullable_str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
nargs="?",
help="The model tag to serve (optional if specified in config)",
)
parser.add_argument(
"--root-path",
type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
"--headless",
action="store_true",
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.",
)
parser.add_argument(
"--middleware",
type=nullable_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
"--api-server-count",
"-asc",
type=int,
default=1,
help="How many API server processes to run.",
)
parser.add_argument(
"--config",
help="Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html",
)
parser = FrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
if args.enable_log_outputs and not args.enable_log_requests:
raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server"
)
return make_arg_parser(parser_for_docs)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utility functions that create ORCA endpoint load report response headers.
"""
import json
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot
logger = init_logger(__name__)
def create_orca_header(
metrics_format: str, named_metrics: list[tuple[str, float]]
) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format
and adds custom metrics to named_metrics.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
- named_metrics (List[Tuple[str, float]]): List of tuples with metric names
and their corresponding double values.
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if metrics_format.lower() not in ["text", "json"]:
logger.warning(
"Warning: `%s` format is not supported in the ORCA response header",
format,
)
return None
header = {}
orca_report = {
"named_metrics": {
metric_name: value
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
}
}
# output example:
# endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
if metrics_format.lower() == "text":
native_http_header = ", ".join(
[
f"named_metrics.{metric_name}={value}"
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
]
)
header["endpoint-load-metrics"] = f"TEXT {native_http_header}"
# output example:
# endpoint-load-metrics: JSON “named_metrics”: {“custom-metric-util”: 0.4}
elif metrics_format.lower() == "json":
header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"
logger.info("Created ORCA header %s", header)
return header
def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
"""
Collects current metrics from Prometheus and returns some of them
in the form of the `named_metrics` list for `create_orca_header()`.
Parameters:
- None
Returns:
- list[tuple[str, float]]: List of tuples of metric names and their values.
"""
named_metrics: list[tuple[str, float]] = []
# Map from prometheus metric names to ORCA named metrics.
prometheus_to_orca_metrics = {
"vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
"vllm:num_requests_waiting": "num_requests_waiting",
}
metrics = get_metrics_snapshot()
for metric in metrics:
orca_name = prometheus_to_orca_metrics.get(metric.name)
# If this metric is mapped into ORCA, then add it to the report.
# Note: Only Gauge metrics are currently supported.
if orca_name is not None and isinstance(metric, Gauge):
named_metrics.append((str(orca_name), float(metric.value)))
return named_metrics
def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format.
Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if not metrics_format:
return None
# Get named metrics from prometheus.
named_metrics = get_named_metrics_from_prometheus()
return create_orca_header(metrics_format, named_metrics)

View File

@@ -0,0 +1,825 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable, Sequence
from typing import Literal
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseOutputItem,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_web_search import (
ActionFind,
ActionOpenPage,
ActionSearch,
ResponseFunctionWebSearch,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
Author,
ChannelConfig,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
from openai_harmony import Message as OpenAIHarmonyMessage
from openai_harmony import Role as OpenAIHarmonyRole
from vllm import envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionToolsParam,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.utils import random_uuid
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
MCP_BUILTIN_TOOLS: set[str] = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: set[str]) -> bool:
"""
Checks if the given tool types are custom tools
(i.e. any tool other than MCP buildin tools)
"""
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: str | None = None,
reasoning_effort: Literal["high", "medium", "low"] | None = None,
start_date: str | None = None,
browser_description: str | None = None,
python_description: str | None = None,
container_description: str | None = None,
instructions: str | None = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
current_identity = sys_msg_content.model_identity
new_identity = (
f"{current_identity}\n{instructions}" if current_identity else instructions
)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]
)
if start_date is None:
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
if not with_custom_tools:
channel_config = sys_msg_content.channel_config
invalid_channel = "commentary"
new_config = ChannelConfig.require_channels(
[c for c in channel_config.valid_channels if c != invalid_channel]
)
sys_msg_content = sys_msg_content.with_channel_config(new_config)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def create_tool_definition(tool: ChatCompletionToolsParam | Tool):
if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: str | None = None,
tools: list[Tool | ChatCompletionToolsParam] | None = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools: list[Tool | ChatCompletionToolsParam] = []
for tool in tools:
if tool.type in (
"web_search_preview",
"code_interpreter",
"container",
):
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
create_tool_definition(tool) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions
)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_response_input(
response_msg: ResponseInputOutputItem,
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
) -> Message:
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: ResponseFunctionToolCall | None = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
and prev_response.call_id == call_id
):
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"],
)
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs: list[Message] = []
tool_id_names: dict[str, str] = {}
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
msgs = auto_drop_analysis_messages(msgs)
return msgs
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle the case where we're in longer multi-turn
conversations and the client gave us reasoning content from previous turns of
the conversation with multiple assistant messages to the final channel in the
conversation.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index = -1
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if msg.author.role == "assistant" and msg.channel == "final":
last_assistant_final_index = i
break
cleaned_msgs: list[Message] = []
for i, msg in enumerate(msgs):
if i < last_assistant_final_index and msg.channel == "analysis":
continue
cleaned_msgs.append(msg)
return cleaned_msgs
def flatten_chat_text_content(content: str | list | None) -> str | None:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
return content
def parse_chat_input_to_harmony_message(
chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names = tool_id_names or {}
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
msgs: list[Message] = []
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls", [])
if role == "assistant" and tool_calls:
content = flatten_chat_text_content(chat_msg.get("content"))
if content:
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
commentary_msg = commentary_msg.with_channel("commentary")
msgs.append(commentary_msg)
reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
"reasoning_content"
)
if reasoning_content:
analysis_msg = Message.from_role_and_content(
Role.ASSISTANT, reasoning_content
)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
# Officially, this should be `<|constrain|>json` but there is not clear
# evidence that improves accuracy over `json` and some anecdotes to the
# contrary. Further testing of the different content_types is needed.
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = (
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Non-tool reasoning content
reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
if role == "assistant" and reasoning_content:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
# Default: user/assistant/system messages with content
content = chat_msg.get("content") or ""
if content is None:
content = ""
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if role == "assistant" and contents and contents[0].text:
msg = Message.from_role_and_contents(role, contents)
# Send non-tool assistant messages to the final channel
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
elif role != "assistant":
msg = Message.from_role_and_contents(role, contents)
msgs.append(msg)
return msgs
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
"""
Parse a message from request.previous_input_messages in the Responsees API to
Harmony messages.
"""
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
).with_channel("commentary")
return [msg]
# Default: user/assistant/system messages with content
content = chat_msg.get("content", "")
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents)
return [msg]
def construct_harmony_previous_input_messages(
request: ResponsesRequest,
) -> list[OpenAIHarmonyMessage]:
messages: list[OpenAIHarmonyMessage] = []
if request.previous_input_messages:
for message in request.previous_input_messages:
# Handle both OpenAIHarmonyMessage objects and dictionary inputs
if isinstance(message, OpenAIHarmonyMessage):
message_role = message.author.role
# To match OpenAI, instructions, reasoning and tools are
# always taken from the most recent Responses API request
# not carried over from previous requests
if (
message_role == OpenAIHarmonyRole.SYSTEM
or message_role == OpenAIHarmonyRole.DEVELOPER
):
continue
messages.append(message)
else:
harmony_messages = parse_input_to_harmony_message(message)
for harmony_msg in harmony_messages:
message_role = harmony_msg.author.role
# To match OpenAI, instructions, reasoning and tools are
# always taken from the most recent Responses API request
# not carried over from previous requests
if (
message_role == OpenAIHarmonyRole.SYSTEM
or message_role == OpenAIHarmonyRole.DEVELOPER
):
continue
messages.append(harmony_msg)
return messages
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT
)
return token_ids
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
"""Parse browser tool calls (search, open, find) into web search items."""
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
# Parse JSON args (with retry detection)
try:
browser_call = json.loads(content.text)
except json.JSONDecodeError:
json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}"
)
browser_call = {
"query": json_retry_output_message,
"url": json_retry_output_message,
"pattern": json_retry_output_message,
}
# Create appropriate action based on recipient
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call.get("pattern", ""),
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
return ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse function calls into function tool call items."""
function_name = recipient.split(".")[-1]
output_items = []
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
return output_items
def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]:
"""Parse reasoning/analysis content into reasoning items."""
output_items = []
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
return output_items
def _parse_final_message(message: Message) -> ResponseOutputItem:
"""Parse final channel messages into output message items."""
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
return ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
def _parse_mcp_recipient(recipient: str) -> tuple[str, str]:
"""
Parse MCP recipient into (server_label, tool_name).
For dotted recipients like "repo_browser.list":
- server_label: "repo_browser" (namespace/server)
- tool_name: "list" (specific tool)
For simple recipients like "filesystem":
- server_label: "filesystem"
- tool_name: "filesystem"
"""
if "." in recipient:
server_label = recipient.split(".")[0]
tool_name = recipient.split(".")[-1]
else:
server_label = recipient
tool_name = recipient
return server_label, tool_name
def _parse_mcp_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse MCP calls into MCP call items."""
server_label, tool_name = _parse_mcp_recipient(recipient)
output_items = []
for content in message.content:
response_item = McpCall(
arguments=content.text,
type="mcp_call",
name=tool_name,
server_label=server_label,
id=f"mcp_{random_uuid()}",
status="completed",
)
output_items.append(response_item)
return output_items
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
"""
Parse a Harmony message into a list of output response items.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None:
# Browser tool calls
if recipient.startswith("browser."):
output_items.append(_parse_browser_tool_call(message, recipient))
# Function calls (should only happen on commentary channel)
elif message.channel == "commentary" and recipient.startswith("functions."):
output_items.extend(_parse_function_call(message, recipient))
# Built-in tools are treated as reasoning
elif recipient.startswith(("python", "browser", "container")):
# Built-in tool recipients (python/browser/container)
# generate reasoning output
output_items.extend(_parse_reasoning_content(message))
# All other recipients are MCP calls
else:
output_items.extend(_parse_mcp_call(message, recipient))
# No recipient - handle based on channel for non-tool messages
elif message.channel == "analysis":
output_items.extend(_parse_reasoning_content(message))
elif message.channel == "commentary":
# Per Harmony format, commentary channel can contain preambles to calling
# multiple functions - explanatory text with no recipient
output_items.extend(_parse_reasoning_content(message))
elif message.channel == "final":
output_items.append(_parse_final_message(message))
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if current_recipient is not None and current_recipient.startswith("browser."):
return []
if current_recipient and parser.current_channel in ("commentary", "analysis"):
if current_recipient.startswith("functions."):
rid = random_uuid()
return [
ResponseFunctionToolCall(
arguments=parser.current_content,
call_id=f"call_{rid}",
type="function_call",
name=current_recipient.split(".")[-1],
id=f"fc_{rid}",
status="in_progress",
)
]
# Built-in tools (python, browser, container) should be treated as reasoning
elif not (
current_recipient.startswith("python")
or current_recipient.startswith("browser")
or current_recipient.startswith("container")
):
# All other recipients are MCP calls
rid = random_uuid()
server_label, tool_name = _parse_mcp_recipient(current_recipient)
return [
McpCall(
arguments=parser.current_content,
type="mcp_call",
name=tool_name,
server_label=server_label,
id=f"mcp_{rid}",
status="in_progress",
)
]
if parser.current_channel == "commentary":
return [
ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
]
if parser.current_channel == "analysis":
return [
ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
]
if parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
# if the parser still has messages (ie if the generator got cut
# abruptly), this should be incomplete
status="incomplete",
type="message",
)
return [text_item]
return []
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use,this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
# Get completed messages from the parser
reasoning_texts = [
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
]
final_texts = [
msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
]
# Extract partial messages from the parser
if parser.current_channel == "analysis" and parser.current_content:
reasoning_texts.append(parser.current_content)
elif parser.current_channel != "analysis" and parser.current_content:
final_texts.append(parser.current_content)
# Flatten multiple messages into a single string
reasoning: str | None = "\n".join(reasoning_texts)
final_content: str | None = "\n".join(final_texts)
# Return None instead of empty string since existing callers check for None
reasoning = reasoning or None
final_content = final_content or None
return reasoning, final_content, is_tool_call

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
Content,
ResponseReasoningItem,
)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = logging.getLogger(__name__)
class ResponsesParser:
"""Incremental parser over completion tokens with reasoning support."""
def __init__(
self,
*,
tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
response_messages
)
self.num_init_messages = len(response_messages)
self.tokenizer = tokenizer
self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
def process(self, output: CompletionOutput) -> "ResponsesParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
output.text, request=self.request
)
if reasoning_content:
self.response_messages.append(
ResponseReasoningItem(
type="reasoning",
id=f"rs_{random_uuid()}",
summary=[],
content=[
Content(
type="reasoning_text",
text=reasoning_content,
)
],
)
)
function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
if content:
self.response_messages.append(
ResponseOutputMessage(
type="message",
id=f"msg_{random_uuid()}",
status="completed",
role="assistant",
content=[
ResponseOutputText(
annotations=[], # TODO
type="output_text",
text=content,
logprobs=None, # TODO
)
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)
return self
def get_responses_parser_for_simple_context(
*,
tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
Returns:
ResponsesParser instance configured with the provided parser
"""
return ResponsesParser(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,631 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import tempfile
from argparse import Namespace
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from io import StringIO
from typing import Any, TypeAlias
import aiohttp
import torch
from prometheus_client import start_http_server
from pydantic import TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
OpenAIBaseModel,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: BatchRequestInputBody
@field_validator("body", mode="plain")
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url: str = info.data["url"]
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url.endswith("/score"):
return ScoreRequest.model_validate(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| None
) = None
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: BatchResponseData | None
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Any | None
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help="The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.",
)
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.",
)
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it "
"to the output URL.",
)
parser.add_argument(
"--response-role",
type=optional_type(str),
default="assistant",
help="The role name to return if `request.add_generation_prompt=True`.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument(
"--max-log-len",
type=int,
default=None,
help="Max number of prompt characters or prompt "
"ID numbers being printed in log."
"\n\nDefault: Unlimited",
)
parser.add_argument(
"--enable-metrics", action="store_true", help="Enable Prometheus metrics"
)
parser.add_argument(
"--url",
type=str,
default="0.0.0.0",
help="URL to the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--enable-prompt-tokens-details",
action="store_true",
default=False,
help="If set to True, enable prompt_tokens_details in usage.",
)
parser.add_argument(
"--enable-force-include-usage",
action="store_true",
default=False,
help="If set to True, include usage on every request "
"(even when stream_options is not specified)",
)
return parser
def parse_args():
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
return make_arg_parser(parser).parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: tqdm | None = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
self._pbar = tqdm(
total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(
output_path: str, batch_outputs: list[BatchRequestOutput]
) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=1000)
) as session:
if from_file:
with open(data_or_file, "rb") as file:
async with session.put(output_url, data=file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload file.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
else:
async with session.put(output_url, data=data_or_file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload data.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
except Exception as e:
if attempt < max_retries:
logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(
path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str
) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
logger.info("Writing outputs to temporary local file %s", f.name)
await write_local_file(f.name, batch_outputs)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
def make_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"vllm-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(
serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker,
) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=response, request_id=f"vllm-batch-{random_uuid()}"
),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=response.error.code,
request_id=f"vllm-batch-{random_uuid()}",
),
error=response,
)
else:
batch_output = make_error_request_output(
request, error_msg="Request must not be sent in stream mode"
)
tracker.completed()
return batch_output
def validate_run_batch_args(args):
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
reasoning_parser := args.structured_outputs_config.reasoning_parser
) and reasoning_parser not in valid_reasoning_parsers:
raise KeyError(
f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
async def run_batch(
engine_client: EngineClient,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
model_config = engine_client.model_config
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
base_model_paths=base_model_paths,
lora_modules=None,
)
openai_serving_chat = (
OpenAIServingChat(
engine_client,
openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
)
openai_serving_embedding = (
OpenAIServingEmbedding(
engine_client,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
)
if "embed" in supported_tasks
else None
)
enable_serving_reranking = (
"classify" in supported_tasks
and getattr(model_config.hf_config, "num_labels", 0) == 1
)
openai_serving_scores = (
ServingScores(
engine_client,
openai_serving_models,
request_logger=request_logger,
)
if ("embed" in supported_tasks or enable_serving_reranking)
else None
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = (
openai_serving_chat.create_chat_completion
if openai_serving_chat is not None
else None
)
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Chat Completions API",
)
)
continue
response_futures.append(run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
embed_handler_fn = (
openai_serving_embedding.create_embedding
if openai_serving_embedding is not None
else None
)
if embed_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
)
)
continue
response_futures.append(run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/score"):
score_handler_fn = (
openai_serving_scores.create_score
if openai_serving_scores is not None
else None
)
if score_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Scores API",
)
)
continue
response_futures.append(run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = (
openai_serving_scores.do_rerank
if openai_serving_scores is not None
else None
)
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
)
)
continue
response_futures.append(run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
)
)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.usage.usage_lib import UsageContext
validate_run_batch_args(args)
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
await run_batch(engine_client, args)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))

File diff suppressed because it is too large Load Diff

View File

@@ -1,67 +1,90 @@
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.utils import merge_async_iterators, random_uuid
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
# set up logits processors
self.logits_processors = self.model_config.logits_processors
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
@@ -75,90 +98,214 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
return self.create_error_response("suffix is not currently supported")
model_name = self.served_model_names[0]
request_id = f"cmpl-{random_uuid()}"
if request.echo and request.prompt_embeds is not None:
return self.create_error_response("Echo is unsupported with prompt embeds.")
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds."
)
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
lora_request = self._maybe_get_adapters(request)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
raise NotImplementedError
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
result_generator = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# We do not stream the results when using beam search.
stream = request.stream and not request.use_beam_search
# Streaming response
if stream:
return self.completion_stream_generator(request,
raw_request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts))
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -179,80 +326,126 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
engine_prompts: list[TokensPrompt | EmbedsPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: list[int] | None = None
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs
or [])
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
top_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i] and request.return_token_ids:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (
not delta_text
and not delta_token_ids
and not previous_num_tokens[i]
):
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
logprobs = self._create_logprobs(
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
if output.finish_reason is not None: # return final usage
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
final_usage = None
response_json = CompletionStreamResponse(
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
@@ -263,58 +456,129 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(
as_list(output.token_ids)
if request.return_token_ids
else None
),
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
except ValueError as e:
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage_info,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: List[RequestOutput],
final_res_batch: list[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
assert final_res is not None
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
output_text = prompt_text + output.text
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
@@ -325,12 +589,19 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(
prompt_token_ids if request.return_token_ids else None
),
token_ids=(
as_list(output.token_ids) if request.return_token_ids else None
),
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
@@ -338,10 +609,121 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if (
self.enable_prompt_tokens_details
and last_final_res
and last_final_res.num_cached_tokens
):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=last_final_res.num_cached_tokens
)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: TokenizerLike | None,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: list[int] = []
out_token_logprobs: list[float | None] = []
out_tokens: list[str] = []
out_top_logprobs: list[dict[str, float] | None] = []
last_token_len = 0
should_return_as_token_id = (
return_as_token_id
if return_as_token_id is not None
else self.return_tokens_as_token_ids
)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_token,
token_id,
tokenizer,
return_as_token_id=should_return_as_token_id,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append(
{
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
): max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
}
)
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,304 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (
ErrorInfo,
ErrorResponse,
LoadLoRAAdapterRequest,
ModelCard,
ModelList,
ModelPermission,
UnloadLoRAAdapterRequest,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.utils.counter import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: str | None = None
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
engine_client: EngineClient,
base_model_paths: list[BaseModelPath],
*,
lora_modules: list[LoRAModulePath] | None = None,
):
super().__init__()
self.engine_client = engine_client
self.base_model_paths = base_model_paths
self.static_lora_modules = lora_modules
self.lora_requests: dict[str, LoRARequest] = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers():
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name)
)
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.input_processor = self.engine_client.input_processor
self.io_processor = self.engine_client.io_processor
self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoRAAdapterRequest(
lora_path=lora.path, lora_name=lora.name
)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name
)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all
adapters"""
model_cards = [
ModelCard(
id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(
id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name
if lora.base_model_name
else self.base_model_paths[0].name,
permission=[ModelPermission()],
)
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_path = request.lora_path
unique_id = self.lora_id_counter.inc(1)
lora_request = LoRARequest(
lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path
)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also preload it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except Exception as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
return create_error_response(
message=str(e), err_type=error_type, status_code=status_code
)
self.lora_requests[lora_name] = lora_request
logger.info(
"Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_unload_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
# Safe to delete now since we hold the lock
del self.lora_requests[lora_name]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# Check if the lora adapter with the given name already exists
if request.lora_name in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' has already been "
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
async def _check_unload_lora_adapter_request(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if 'lora_name' is not provided return an error
if not request.lora_name:
return create_error_response(
message="'lora_name' needs to be provided to unload a LoRA adapter.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# Check if the lora adapter with the given name exists
if request.lora_name not in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
return None
async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
"""Attempt to resolve a LoRA adapter using available resolvers.
Args:
lora_name: Name/identifier of the LoRA adapter
Returns:
LoRARequest if found and loaded successfully.
ErrorResponse (404) if no resolver finds the adapter.
ErrorResponse (400) if adapter(s) are found but none load.
"""
async with self.lora_resolver_lock[lora_name]:
# First check if this LoRA is already loaded
if lora_name in self.lora_requests:
return self.lora_requests[lora_name]
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
found_adapter = False
# Try to resolve using available resolvers
for resolver in self.lora_resolvers:
lora_request = await resolver.resolve_lora(base_model_name, lora_name)
if lora_request is not None:
found_adapter = True
lora_request.lora_int_id = unique_id
try:
await self.engine_client.add_lora(lora_request)
self.lora_requests[lora_name] = lora_request
logger.info(
"Resolved and loaded LoRA adapter '%s' using %s",
lora_name,
resolver.__class__.__name__,
)
return lora_request
except BaseException as e:
logger.warning(
"Failed to load LoRA '%s' resolved by %s: %s. "
"Trying next resolver.",
lora_name,
resolver.__class__.__name__,
e,
)
continue
if found_adapter:
# An adapter was found, but all attempts to load it failed.
return create_error_response(
message=(
f"LoRA adapter '{lora_name}' was found but could not be loaded."
),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST,
)
else:
# No adapter was found
return create_error_response(
message=f"LoRA adapter {lora_name} does not exist",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
RequestResponseMetadata,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionStreamResponse,
TranslationRequest,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
logger = init_logger(__name__)
class OpenAIServingTranscription(OpenAISpeechToText):
"""Handles transcription requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
) -> (
TranscriptionResponse
| TranscriptionResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranscriptionResponseVerbose
if request.response_format == "verbose_json"
else TranscriptionResponse
),
stream_generator_method=self.transcription_stream_generator,
)
async def transcription_stream_generator(
self,
request: TranscriptionRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="transcription.chunk",
response_stream_choice_class=TranscriptionResponseStreamChoice,
stream_response_class=TranscriptionStreamResponse,
)
async for chunk in generator:
yield chunk
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
) -> (
TranslationResponse
| TranslationResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranslationResponseVerbose
if request.response_format == "verbose_json"
else TranslationResponse
),
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self,
request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
)
async for chunk in generator:
yield chunk

View File

@@ -0,0 +1,559 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
DeltaMessage,
ErrorResponse,
RequestResponseMetadata,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionSegment,
TranscriptionStreamResponse,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationSegment,
TranslationStreamResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
)
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
T = TypeVar("T", bound=SpeechToTextResponse)
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
S = TypeVar("S", bound=SpeechToTextSegment)
ResponseType: TypeAlias = (
TranscriptionResponse
| TranslationResponse
| TranscriptionResponseVerbose
| TranslationResponseVerbose
)
logger = init_logger(__name__)
class OpenAISpeechToText(OpenAIServing):
"""Base class for speech-to-text operations like transcription and
translation."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
)
self.enable_force_include_usage = enable_force_include_usage
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.model_cls.supports_segment_timestamp:
self.tokenizer = cast(
PreTrainedTokenizerBase,
get_tokenizer(
tokenizer_name=self.model_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode,
),
)
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params,
)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = (
self.model_cls.validate_language(request.to_language)
if request.to_language
else None
)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s
)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
if request.response_format == "verbose_json":
if not isinstance(prompt, dict):
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
prompt_dict = cast(dict, prompt)
decoder_prompt = prompt.get("decoder_prompt")
if not isinstance(decoder_prompt, str):
raise ValueError(
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
)
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
"<|notimestamps|>", "<|0.00|>"
)
prompts.append(prompt)
return prompts, duration
def _get_verbose_segments(
self,
tokens: tuple,
request: SpeechToTextRequest,
segment_class: type[SpeechToTextSegment],
start_time: float = 0,
) -> list[SpeechToTextSegment]:
"""
Convert tokens to verbose segments.
This method expects the model to produce
timestamps as tokens (similar to Whisper).
If the tokens do not include timestamp information,
the segments may not be generated correctly.
Note: Fields like avg_logprob, compression_ratio,
and no_speech_prob are not supported
in this implementation and will be None. See docs for details.
"""
BASE_OFFSET = 0.02
init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
if tokens[-1] == self.tokenizer.eos_token_id:
tokens = tokens[:-1]
tokens_with_start = (init_token,) + tokens
segments: list[SpeechToTextSegment] = []
last_timestamp_start = 0
if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
for idx, token in enumerate(tokens_with_start):
# Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
# If the ordering is violated, this slicing may produce incorrect results.
if (
token >= init_token
and idx != 0
and tokens_with_start[idx - 1] >= init_token
):
sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
start_timestamp = sliced_timestamp_tokens[0] - init_token
end_timestamp = sliced_timestamp_tokens[-1] - init_token
casting_segment = cast(
SpeechToTextSegment,
segment_class(
id=len(segments),
seek=start_time,
start=start_time + BASE_OFFSET * start_timestamp,
end=start_time + BASE_OFFSET * end_timestamp,
temperature=request.temperature,
text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
tokens=sliced_timestamp_tokens[1:-1],
),
)
segments.append(casting_segment)
last_timestamp_start = idx
return segments
async def _create_speech_to_text(
self,
audio_data: bytes,
request: SpeechToTextRequest,
raw_request: Request,
response_class: type[T | V],
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
"""Base method for speech-to-text operations like transcription and
translation."""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
if request.response_format not in ["text", "json", "verbose_json"]:
return self.create_error_response(
("Currently only support response_format")
+ ("`text`, `json` or `verbose_json`")
)
if (
request.response_format == "verbose_json"
and not self.model_cls.supports_segment_timestamp
):
return self.create_error_response(
f"Currently do not support verbose_json for {request.model}"
)
if request.response_format == "verbose_json" and request.stream:
return self.create_error_response(
"verbose_json format doesn't support streaming case"
)
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=lora_request,
)
list_result_generator = [
self.engine_client.generate(
prompt,
sampling_params,
f"{request_id}_{i}",
lora_request=lora_request,
)
for i, prompt in enumerate(prompts)
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if request.stream:
return stream_generator_method(
request, list_result_generator, request_id, request_metadata, duration_s
)
# Non-streaming response.
total_segments = []
text_parts = []
try:
assert list_result_generator is not None
segments_types: dict[str, type[SpeechToTextSegment]] = {
"transcribe": TranscriptionSegment,
"translate": TranslationSegment,
}
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
text = ""
for idx, result_generator in enumerate(list_result_generator):
async for op in result_generator:
if request.response_format == "verbose_json":
segments: list[SpeechToTextSegment] = (
self._get_verbose_segments(
tokens=tuple(op.outputs[0].token_ids),
segment_class=segment_class,
request=request,
start_time=idx * self.asr_config.max_audio_clip_s,
)
)
total_segments.extend(segments)
text_parts.extend([seg.text for seg in segments])
else:
text_parts.append(op.outputs[0].text)
text = "".join(text_parts)
if self.task_type == "transcribe":
final_response: ResponseType
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per openAI specs
"seconds": int(math.ceil(duration_s)),
}
if request.response_format != "verbose_json":
final_response = cast(
T, TranscriptionResponse(text=text, usage=usage)
)
else:
final_response = cast(
V,
TranscriptionResponseVerbose(
text=text,
language=request.language,
duration=str(duration_s),
segments=total_segments,
),
)
else:
# no usage in response for translation task
if request.response_format != "verbose_json":
final_response = cast(T, TranslationResponse(text=text))
else:
final_response = cast(
V,
TranslationResponseVerbose(
text=text,
language=request.language,
duration=str(duration_s),
segments=total_segments,
),
)
return final_response
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def _speech_to_text_stream_generator(
self,
request: SpeechToTextRequest,
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
response_stream_choice_class: type[TranscriptionResponseStreamChoice]
| type[TranslationResponseStreamChoice],
stream_response_class: type[TranscriptionStreamResponse]
| type[TranslationStreamResponse],
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
completion_tokens = 0
num_prompt_tokens = 0
include_usage = self.enable_force_include_usage or request.stream_include_usage
include_continuous_usage = (
request.stream_continuous_usage_stats
if include_usage and request.stream_continuous_usage_stats
else False
)
try:
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config, self.model_config
):
num_prompt_tokens += audio_tokens
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = response_stream_choice_class(delta=delta_message)
else:
# Model is finished generating.
choice_data = response_stream_choice_class(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
if include_usage:
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
def _split_audio(
self, audio_data: np.ndarray, sample_rate: int
) -> list[np.ndarray]:
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start, search_end)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
min_energy_window = self.asr_config.min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i : i + min_energy_window]
energy = (window**2).mean() ** 0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
def __getattr__(name: str):
if name == "ToolParser":
from vllm.tool_parsers import ToolParser
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to "
"`vllm.tool_parsers.ToolParser`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
return ToolParser
if name == "ToolParserManager":
from vllm.tool_parsers import ToolParserManager
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParserManager` "
"has been moved to `vllm.tool_parsers.ToolParserManager`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
return ToolParserManager
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeVar
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
)
# Used internally
_ChatCompletionResponseChoiceT = TypeVar(
"_ChatCompletionResponseChoiceT",
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
)
def maybe_filter_parallel_tool_calls(
choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
) -> _ChatCompletionResponseChoiceT:
"""Filter to first tool call only when parallel_tool_calls is False."""
if request.parallel_tool_calls:
return choice
if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
choice.message.tool_calls = choice.message.tool_calls[:1]
elif (
isinstance(choice, ChatCompletionResponseStreamChoice)
and choice.delta.tool_calls
):
choice.delta.tool_calls = [
tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
]
return choice
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
)