This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

View File

@@ -0,0 +1,545 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import signal
import socket
import tempfile
import warnings
from argparse import Namespace
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
import uvloop
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from starlette.datastructures import State
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
get_uvicorn_log_config,
http_exception_handler,
lifespan,
log_response,
validation_exception_handler,
)
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
cli_env_setup,
log_non_default_args,
log_version_and_model,
process_lora_modules,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.tracing import instrument
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server")
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)
@asynccontextmanager
async def build_async_engine_client(
args: Namespace,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: bool | None = None,
client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
# The executor is expected to be mp.
# Pre-import heavy modules in the forkserver process
logger.debug("Setup forkserver with pre-imports")
multiprocessing.set_start_method("forkserver")
multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
forkserver.ensure_running()
logger.debug("Forkserver setup complete!")
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
if client_config:
engine_args._api_process_count = client_config.get("client_count", 1)
engine_args._api_process_rank = client_config.get("client_index", 0)
if disable_frontend_multiprocessing is None:
disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)
async with build_async_engine_client_from_engine_args(
engine_args,
usage_context=usage_context,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
client_config=client_config,
) as engine:
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: bool = False,
client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Create the EngineConfig (determines if we can use V1).
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if disable_frontend_multiprocessing:
logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: AsyncLLM | None = None
# Don't mutate the input client_config
client_config = dict(client_config) if client_config else {}
client_count = client_config.pop("client_count", 1)
client_index = client_config.pop("client_index", 0)
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
enable_log_requests=engine_args.enable_log_requests,
aggregate_engine_logging=engine_args.aggregate_engine_logging,
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_count=client_count,
client_index=client_index,
)
# Don't keep the dummy data in memory
assert async_llm is not None
await async_llm.reset_mm_cache()
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
def build_app(
args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None
) -> FastAPI:
if supported_tasks is None:
warnings.warn(
"The 'supported_tasks' parameter was not provided to "
"build_app and will be required in a future version. "
"Defaulting to ('generate',).",
DeprecationWarning,
stacklevel=2,
)
supported_tasks = _FALLBACK_SUPPORTED_TASKS
if args.disable_fastapi_docs:
app = FastAPI(
openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
)
elif args.enable_offline_docs:
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.state.args = args
from vllm.entrypoints.serve import register_vllm_serve_api_routers
register_vllm_serve_api_routers(app)
from vllm.entrypoints.openai.models.api_router import (
attach_router as register_models_api_router,
)
register_models_api_router(app)
from vllm.entrypoints.sagemaker.api_router import (
attach_router as register_sagemaker_api_router,
)
register_sagemaker_api_router(app, supported_tasks)
if "generate" in supported_tasks:
from vllm.entrypoints.openai.generate.api_router import (
register_generate_api_routers,
)
register_generate_api_routers(app)
from vllm.entrypoints.serve.disagg.api_router import (
attach_router as attach_disagg_router,
)
attach_disagg_router(app)
from vllm.entrypoints.serve.rlhf.api_router import (
attach_router as attach_rlhf_router,
)
attach_rlhf_router(app)
from vllm.entrypoints.serve.elastic_ep.api_router import (
attach_router as elastic_ep_attach_router,
)
elastic_ep_attach_router(app)
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.speech_to_text.api_router import (
attach_router as register_speech_to_text_api_router,
)
register_speech_to_text_api_router(app)
if "realtime" in supported_tasks:
from vllm.entrypoints.openai.realtime.api_router import (
attach_router as register_realtime_api_router,
)
register_realtime_api_router(app)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import register_pooling_api_routers
register_pooling_api_routers(app, supported_tasks)
app.root_path = args.root_path
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
app.exception_handler(HTTPException)(http_exception_handler)
app.exception_handler(RequestValidationError)(validation_exception_handler)
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware
app.add_middleware(AuthenticationMiddleware, tokens=tokens)
if args.enable_request_id_headers:
from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware
app.add_middleware(XRequestIdMiddleware)
# Add scaling middleware to check for scaling state
app.add_middleware(ScalingMiddleware)
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning(
"CAUTION: Enabling log response in the API Server. "
"This can include sensitive information and should be "
"avoided in production."
)
app.middleware("http")(log_response)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(
f"Invalid middleware {middleware}. Must be a function or a class."
)
app = sagemaker_standards_bootstrap(app)
return app
async def init_app_state(
engine_client: EngineClient,
state: State,
args: Namespace,
supported_tasks: tuple["SupportedTask", ...] | None = None,
) -> None:
vllm_config = engine_client.vllm_config
if supported_tasks is None:
warnings.warn(
"The 'supported_tasks' parameter was not provided to "
"init_app_state and will be required in a future version. "
"Please pass 'supported_tasks' explicitly.",
DeprecationWarning,
stacklevel=2,
)
supported_tasks = _FALLBACK_SUPPORTED_TASKS
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config
state.args = args
resolved_chat_template = load_chat_template(args.chat_template)
# Merge default_mm_loras into the static lora_modules
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
else {}
)
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
)
await state.openai_serving_models.init_static_loras()
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks:
from vllm.entrypoints.openai.generate.api_router import init_generate_state
await init_generate_state(
engine_client, state, args, request_logger, supported_tasks
)
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.speech_to_text.api_router import (
init_transcription_state,
)
init_transcription_state(
engine_client, state, args, request_logger, supported_tasks
)
if "realtime" in supported_tasks:
from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import init_pooling_state
init_pooling_state(engine_client, state, args, request_logger, supported_tasks)
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
family = socket.AF_INET
if is_valid_ipv6_address(addr[0]):
family = socket.AF_INET6
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(addr)
return sock
def create_server_unix_socket(path: str) -> socket.socket:
sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
sock.bind(path)
return sock
def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.list_registered()
if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parses:
raise KeyError(
f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})"
)
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
reasoning_parser := args.structured_outputs_config.reasoning_parser
) and reasoning_parser not in valid_reasoning_parsers:
raise KeyError(
f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
@instrument(span_name="API server setup")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
log_version_and_model(logger, VLLM_VERSION, args.model)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
if args.uds:
sock = create_server_unix_socket(args.uds)
else:
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
if args.uds:
listen_address = f"unix:{args.uds}"
else:
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
# Add process-specific prefix to stdout and stderr.
decorate_logs("APIServer")
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
async def run_server_worker(
listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Get uvicorn log config (from file or with endpoint filter)
log_config = get_uvicorn_log_config(args)
if log_config is not None:
uvicorn_kwargs["log_config"] = log_config
async with build_async_engine_client(
args,
client_config=client_config,
) as engine_client:
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
app = build_app(args, supported_tasks)
await init_app_state(engine_client, app.state, args, supported_tasks)
logger.info(
"Starting vLLM API server %d on %s",
engine_client.vllm_config.parallel_config._api_process_rank,
listen_address,
)
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
ssl_ciphers=args.ssl_ciphers,
h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
h11_max_header_count=args.h11_max_header_count,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
try:
await shutdown_task
finally:
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
)
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
def chat(request: Request) -> OpenAIServingChat | None:
return request.app.state.openai_serving_chat
@router.post(
"/v1/chat/completions",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
metrics_header_format = raw_request.headers.get(
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
)
handler = chat(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Chat Completions API"
)
try:
generator = await handler.create_chat_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(
content=generator.model_dump(),
headers=metrics_header(metrics_header_format),
)
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/chat/completions/render",
dependencies=[Depends(validate_json_request)],
response_model=list,
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
"""Render chat completion request and return conversation and engine
prompts without generating."""
handler = chat(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Chat Completions API"
)
try:
result = await handler.render_chat_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result)
def attach_router(app: FastAPI):
app.include_router(router)

View File

@@ -0,0 +1,733 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Annotated, Any, ClassVar, Literal
import torch
from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import Field, model_validator
from vllm.config import ModelConfig
from vllm.config.utils import replace
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
DeltaMessage,
FunctionCall,
FunctionDefinition,
LegacyStructuralTagResponseFormat,
OpenAIBaseModel,
StreamOptions,
StructuralTagResponseFormat,
ToolCall,
UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class ChatMessage(OpenAIBaseModel):
role: str
content: str | None = None
refusal: str | None = None
annotations: OpenAIAnnotation | None = None
audio: OpenAIChatCompletionAudio | None = None
function_call: FunctionCall | None = None
tool_calls: list[ToolCall] = Field(default_factory=list)
# vLLM-specific fields that are not in OpenAI spec
reasoning: str | None = None
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: list[int] | None = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
# Workaround: redefine fields name cache so that it's not
# shared with the super class.
field_names: ClassVar[set[str] | None] = None
top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: list[ChatCompletionLogProbsContent] | None = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: ChatCompletionLogProbs | None = None
# per OpenAI spec this is the default
finish_reason: str | None = "stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason: int | str | None = None
# not part of the OpenAI spec but is useful for tracing the tokens
# in agent scenarios
token_ids: list[int] | None = None
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[ChatCompletionResponseChoice]
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
system_fingerprint: str | None = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: ChatCompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
# not part of the OpenAI spec but for tracing the tokens
token_ids: list[int] | None = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[ChatCompletionResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
# not part of the OpenAI spec but for tracing the tokens
prompt_token_ids: list[int] | None = None
class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: list[ChatCompletionMessageParam]
model: str | None = None
frequency_penalty: float | None = 0.0
logit_bias: dict[str, float] | None = None
logprobs: bool | None = False
top_logprobs: int | None = 0
max_tokens: int | None = Field(
default=None,
deprecated="max_tokens is deprecated in favor of "
"the max_completion_tokens field",
)
max_completion_tokens: int | None = None
n: int | None = 1
presence_penalty: float | None = 0.0
response_format: AnyResponseFormat | None = None
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
tools: list[ChatCompletionToolsParam] | None = None
tool_choice: (
Literal["none"]
| Literal["auto"]
| Literal["required"]
| ChatCompletionNamedToolChoiceParam
| None
) = "none"
reasoning_effort: Literal["low", "medium", "high"] | None = None
include_reasoning: bool = True
parallel_tool_calls: bool | None = True
# NOTE this will be ignored by vLLM
user: str | None = None
# --8<-- [start:chat-completion-sampling-params]
use_beam_search: bool = False
top_k: int | None = None
min_p: float | None = None
repetition_penalty: float | None = None
length_penalty: float = 1.0
stop_token_ids: list[int] | None = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
None
)
prompt_logprobs: int | None = None
allowed_token_ids: list[int] | None = None
bad_words: list[str] = Field(default_factory=list)
# --8<-- [end:chat-completion-sampling-params]
# --8<-- [start:chat-completion-extra-params]
echo: bool = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."
),
)
add_generation_prompt: bool = Field(
default=True,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message: bool = Field(
default=False,
description=(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
'This allows you to "prefill" part of the model\'s response for it. '
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
documents: list[dict[str, str]] | None = Field(
default=None,
description=(
"A list of dicts representing documents that will be accessible to "
"the model if it is performing RAG (retrieval-augmented generation)."
" If the template does not support RAG, this argument will have no "
"effect. We recommend that each document should be a dict containing "
'"title" and "text" keys.'
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
structured_outputs: StructuredOutputsParams | None = Field(
default=None,
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
return_tokens_as_token_ids: bool | None = Field(
default=None,
description=(
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."
),
)
return_token_ids: bool | None = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."
),
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
documents=self.documents,
reasoning_effort=self.reasoning_effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
max_output_tokens_param = "max_completion_tokens"
else:
max_output_tokens = self.max_tokens
max_output_tokens_param = "max_tokens"
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=max_output_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param=max_output_tokens_param,
)
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
}
def to_beam_search_params(
self, max_tokens: int, default_sampling_params: dict
) -> BeamSearchParams:
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
max_tokens: int,
default_sampling_params: dict,
) -> SamplingParams:
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
response_format = self.response_format
if response_format is not None:
structured_outputs_kwargs = dict[str, Any]()
# Set structured output params for response format
if response_format.type == "json_object":
structured_outputs_kwargs["json_object"] = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if len(structured_outputs_kwargs) > 0:
self.structured_outputs = (
StructuredOutputsParams(**structured_outputs_kwargs)
if self.structured_outputs is None
else replace(self.structured_outputs, **structured_outputs_kwargs)
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
bad_words=self.bad_words,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0 and top_logprobs != -1:
raise VLLMValidationError(
"`top_logprobs` must be a positive value or -1.",
parameter="top_logprobs",
value=top_logprobs,
)
if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
raise VLLMValidationError(
"when using `top_logprobs`, `logprobs` must be set to true.",
parameter="top_logprobs",
)
return data
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):
if isinstance(data, ValueError):
raise data
if data.get("structured_outputs", None) is None:
return data
structured_outputs_kwargs = data["structured_outputs"]
# structured_outputs may arrive as a dict (from JSON/raw kwargs) or
# as a StructuredOutputsParams dataclass instance.
is_dataclass = isinstance(structured_outputs_kwargs, StructuredOutputsParams)
count = sum(
(
getattr(structured_outputs_kwargs, k, None)
if is_dataclass
else structured_outputs_kwargs.get(k)
)
is not None
for k in ("json", "regex", "choice")
)
# you can only use one kind of constraints for structured outputs
if count > 1:
raise ValueError(
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice')."
)
# you can only either use structured outputs or tools, not both
if count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use constraints for structured outputs "
"or tools, not both."
)
return data
@model_validator(mode="before")
@classmethod
def check_tool_usage(cls, data):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if "tool_choice" not in data and data.get("tools"):
data["tool_choice"] = "auto"
# if "tool_choice" is "none" -- no validation is needed for tools
if "tool_choice" in data and data["tool_choice"] == "none":
return data
# if "tool_choice" is specified -- validation
if "tool_choice" in data and data["tool_choice"] is not None:
# ensure that if "tool choice" is specified, tools are present
if "tools" not in data or data["tools"] is None:
raise ValueError("When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in ["auto", "required"] and not isinstance(
data["tool_choice"], dict
):
raise ValueError(
f"Invalid value for `tool_choice`: {data['tool_choice']}! "
'Only named tools, "none", "auto" or "required" '
"are supported."
)
# if tool_choice is "required" but the "tools" list is empty,
# override the data to behave like "none" to align with
# OpenAIs behavior.
if (
data["tool_choice"] == "required"
and isinstance(data["tools"], list)
and len(data["tools"]) == 0
):
data["tool_choice"] = "none"
del data["tools"]
return data
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
correct_usage_message = (
'Correct usage: `{"type": "function",'
' "function": {"name": "my_function"}}`'
)
if isinstance(data["tool_choice"], dict):
valid_tool = False
function = data["tool_choice"].get("function")
if not isinstance(function, dict):
raise ValueError(
f"Invalid value for `function`: `{function}` in "
f"`tool_choice`! {correct_usage_message}"
)
if "name" not in function:
raise ValueError(
f"Expected field `name` in `function` in "
f"`tool_choice`! {correct_usage_message}"
)
function_name = function["name"]
if not isinstance(function_name, str) or len(function_name) == 0:
raise ValueError(
f"Invalid `name` in `function`: `{function_name}`"
f" in `tool_choice`! {correct_usage_message}"
)
for tool in data["tools"]:
if tool["function"]["name"] == function_name:
valid_tool = True
break
if not valid_tool:
raise ValueError(
"The tool specified in `tool_choice` does not match any"
" of the specified `tools`"
)
return data
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
@model_validator(mode="before")
@classmethod
def check_system_message_content_type(cls, data):
"""Warn if system messages contain non-text content.
According to OpenAI API spec, system messages can only be of type
'text'. We log a warning instead of rejecting to avoid breaking
users who intentionally send multimodal system messages.
See: https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-system_message
"""
if not isinstance(data, dict):
return data
messages = data.get("messages", [])
for msg in messages:
# Check if this is a system message
if isinstance(msg, dict) and msg.get("role") == "system":
content = msg.get("content")
# If content is a list (multimodal format)
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
part_type = part.get("type")
# Infer type when 'type' field is not explicit
if part_type is None:
if "image_url" in part or "image_pil" in part:
part_type = "image_url"
elif "image_embeds" in part:
part_type = "image_embeds"
elif "audio_url" in part:
part_type = "audio_url"
elif "input_audio" in part:
part_type = "input_audio"
elif "audio_embeds" in part:
part_type = "audio_embeds"
elif "video_url" in part:
part_type = "video_url"
# Warn about non-text content in system messages
if part_type and part_type != "text":
logger.warning_once(
"System messages should only contain text "
"content according to the OpenAI API spec. "
"Found content type: '%s'.",
part_type,
)
return data

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Harmony-specific streaming delta extraction for chat completions.
This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions.
"""
from typing import NamedTuple
from openai_harmony import StreamableParser
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
class TokenState(NamedTuple):
channel: str | None
recipient: str | None
text: str
def extract_harmony_streaming_delta(
harmony_parser: StreamableParser,
token_states: list[TokenState],
prev_recipient: str | None,
include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]:
"""
Extract a DeltaMessage from harmony parser state during streaming.
Args:
harmony_parser: The StreamableParser instance tracking parse state
token_states: List of TokenState tuples for each token
prev_recipient: Previous recipient for detecting tool call transitions
include_reasoning: Whether to include reasoning content
Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag)
"""
if not token_states:
return None, False
tools_streamed = False
# Group consecutive tokens with same channel/recipient
groups: list[TokenState] = []
current_channel = token_states[0].channel
current_recipient = token_states[0].recipient
current_text = token_states[0].text
for i in range(1, len(token_states)):
state = token_states[i]
if state.channel == current_channel and state.recipient == current_recipient:
current_text += state.text
else:
groups.append(TokenState(current_channel, current_recipient, current_text))
current_channel = state.channel
current_recipient = state.recipient
current_text = state.text
groups.append(TokenState(current_channel, current_recipient, current_text))
# Process each group and create delta messages
delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []
content_encountered = False
# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
# If there's an ongoing tool call from previous chunk,
# the next new tool call starts at base_index + 1
if prev_recipient and prev_recipient.startswith("functions."):
next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None
for group in groups:
if group.channel == "final":
combined_content += group.text
content_encountered = True
elif (
(group.channel == "commentary" or group.channel == "analysis")
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_name,
arguments="",
),
index=next_tool_index,
)
)
opened_new_call = True
prev_recipient = group.recipient
# Increment for subsequent new tool calls
next_tool_index += 1
if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall(
index=tool_call_index,
function=DeltaFunctionCall(arguments=group.text),
)
)
elif group.channel == "commentary" and group.recipient is None:
# Tool call preambles meant to be shown to the user
combined_content += group.text
content_encountered = True
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text
# Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True
delta_message = DeltaMessage(**delta_kwargs)
else:
delta_message = None
return delta_message, tools_streamed

View File

@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from collections.abc import Sequence
from dataclasses import field
from typing import Any, Literal
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
validate_chat_template,
)
from vllm.entrypoints.constants import (
H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.logger import init_logger
from vllm.tool_parsers import ToolParserManager
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str | Sequence[str] | None,
option_string: str | None = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: list[LoRAModulePath] = []
for item in values:
if item in [None, ""]: # Skip if item is None or empty string
continue
if "=" in item and "," not in item: # Old format: name=path
name, path = item.split("=")
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
@config
class BaseFrontendArgs:
"""Base arguments for the OpenAI-compatible frontend server.
This base class does not include host, port, and server-specific arguments
like SSL, CORS, and HTTP server settings. Those arguments are added by
the subclasses.
"""
lora_modules: list[LoRAModulePath] | None = None
"""LoRA modules configurations in either 'name=path' format or JSON format
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`"""
chat_template: str | None = None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
default_chat_template_kwargs: dict[str, Any] | None = None
"""Default keyword arguments to pass to the chat template renderer.
These will be merged with request-level chat_template_kwargs,
with request values taking precedence. Useful for setting default
behavior for reasoning models. Example: '{"enable_thinking": false}'
to disable thinking mode by default for Qwen3/DeepSeek models."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
return_tokens_as_token_ids: bool = False
"""When `--max-logprobs` is specified, represents single tokens as
strings of the form 'token_id:{token_id}' so that tokens that are not
JSON-encodable can be identified."""
disable_frontend_multiprocessing: bool = False
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
enable_auto_tool_choice: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: str | None = None
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
Required for `--enable-auto-tool-choice`. You can choose any option from
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
tool_parser_plugin: str = ""
"""Special the tool parser plugin write to parse the model-generated tool
into OpenAI API format, the name register in this plugin can be used in
`--tool-call-parser`."""
tool_server: str | None = None
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
purpose."""
log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH
"""Path to logging config JSON file for both vllm and uvicorn"""
max_log_len: int | None = None
"""Max number of prompt characters or prompt ID numbers being printed in
log. The default of None means unlimited."""
enable_prompt_tokens_details: bool = False
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the `/tokenizer_info` endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, log model outputs (generations).
Requires --enable-log-requests."""
enable_log_deltas: bool = True
"""If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set.
"""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
tokens_only: bool = False
"""
If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup.
"""
@classmethod
def _customize_cli_kwargs(
cls,
frontend_kwargs: dict[str, Any],
) -> dict[str, Any]:
"""Customize argparse kwargs before arguments are registered.
Subclasses should override this and call
``super()._customize_cli_kwargs(frontend_kwargs)`` first.
"""
# Special case: default_chat_template_kwargs needs json.loads type
frontend_kwargs["default_chat_template_kwargs"]["type"] = json.loads
# Special case: LoRA modules need custom parser action and
# optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
frontend_kwargs["tool_call_parser"]["metavar"] = (
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
)
return frontend_kwargs
@classmethod
def add_cli_args(cls, parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Register CLI arguments for this frontend class.
Subclasses should override ``_customize_cli_kwargs`` instead of
this method so that base-class postprocessing is always applied.
"""
from vllm.engine.arg_utils import get_kwargs
frontend_kwargs = get_kwargs(cls)
frontend_kwargs = cls._customize_cli_kwargs(frontend_kwargs)
group_name = cls.__name__.replace("Args", "")
frontend_group = parser.add_argument_group(
title=group_name,
description=cls.__doc__,
)
for key, value in frontend_kwargs.items():
extra_flags = value.pop("flags", [])
frontend_group.add_argument(
*extra_flags, f"--{key.replace('_', '-')}", **value
)
return parser
@config
class FrontendArgs(BaseFrontendArgs):
"""Arguments for the OpenAI-compatible frontend server."""
host: str | None = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: str | None = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal[
"critical", "error", "warning", "info", "debug", "trace"
] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: list[str] | None = None
"""If provided, the server will require one of these keys to be presented in
the header."""
ssl_keyfile: str | None = None
"""The file path to the SSL key file."""
ssl_certfile: str | None = None
"""The file path to the SSL cert file."""
ssl_ca_certs: str | None = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
ssl_ciphers: str | None = None
"""SSL cipher suites for HTTPS (TLS 1.2 and below only).
Example: 'ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305'"""
root_path: str | None = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
enable_offline_docs: bool = False
"""
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
@classmethod
def _customize_cli_kwargs(
cls,
frontend_kwargs: dict[str, Any],
) -> dict[str, Any]:
frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)
# Special case: allowed_origins, allowed_methods, allowed_headers all
# need json.loads type
# Should also remove nargs
frontend_kwargs["allowed_origins"]["type"] = json.loads
frontend_kwargs["allowed_methods"]["type"] = json.loads
frontend_kwargs["allowed_headers"]["type"] = json.loads
del frontend_kwargs["allowed_origins"]["nargs"]
del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str
if "nargs" in frontend_kwargs["middleware"]:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
return frontend_kwargs
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Create the CLI argument parser used by the OpenAI API server.
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
register all arguments instead of manually enumerating them here. This
avoids code duplication and keeps the argument definitions in one place.
"""
parser.add_argument(
"model_tag",
type=str,
nargs="?",
help="The model tag to serve (optional if specified in config)",
)
parser.add_argument(
"--headless",
action="store_true",
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.",
)
parser.add_argument(
"--api-server-count",
"-asc",
type=int,
default=None,
help="How many API server processes to run. "
"Defaults to data_parallel_size if not specified.",
)
parser.add_argument(
"--config",
help="Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html",
)
parser = FrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
if args.enable_log_outputs and not args.enable_log_requests:
raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server"
)
return make_arg_parser(parser_for_docs)

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
CompletionResponse,
)
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
def completion(request: Request) -> OpenAIServingCompletion | None:
return request.app.state.openai_serving_completion
@router.post(
"/v1/completions",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
metrics_header_format = raw_request.headers.get(
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
)
handler = completion(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Completions API"
)
try:
generator = await handler.create_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, CompletionResponse):
return JSONResponse(
content=generator.model_dump(),
headers=metrics_header(metrics_header_format),
)
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/completions/render",
dependencies=[Depends(validate_json_request)],
response_model=list,
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
async def render_completion(request: CompletionRequest, raw_request: Request):
"""render completion request and return engine prompts without generating."""
handler = completion(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Completions API"
)
try:
result = await handler.render_completion_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result)
def attach_router(app: FastAPI):
app.include_router(router)

View File

@@ -0,0 +1,475 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Annotated, Any, Literal
import torch
from pydantic import Field, model_validator
from vllm.config import ModelConfig
from vllm.config.utils import replace
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
LegacyStructuralTagResponseFormat,
OpenAIBaseModel,
StreamOptions,
StructuralTagResponseFormat,
UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str | None = None
prompt: (
list[Annotated[int, Field(ge=0)]]
| list[list[Annotated[int, Field(ge=0)]]]
| str
| list[str]
| None
) = None
echo: bool | None = False
frequency_penalty: float | None = 0.0
logit_bias: dict[str, float] | None = None
logprobs: int | None = None
max_tokens: int | None = 16
n: int = 1
presence_penalty: float | None = 0.0
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
# --8<-- [start:completion-sampling-params]
use_beam_search: bool = False
top_k: int | None = None
min_p: float | None = None
repetition_penalty: float | None = None
length_penalty: float = 1.0
stop_token_ids: list[int] | None = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
None
)
allowed_token_ids: list[int] | None = None
prompt_logprobs: int | None = None
# --8<-- [end:completion-sampling-params]
# --8<-- [start:completion-extra-params]
prompt_embeds: bytes | list[bytes] | None = None
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
response_format: AnyResponseFormat | None = Field(
default=None,
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
structured_outputs: StructuredOutputsParams | None = Field(
default=None,
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
return_tokens_as_token_ids: bool | None = Field(
default=None,
description=(
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."
),
)
return_token_ids: bool | None = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."
),
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
vllm_xargs: dict[str, str | int | float] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_tokens",
)
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
}
def to_beam_search_params(
self,
max_tokens: int,
default_sampling_params: dict | None = None,
) -> BeamSearchParams:
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
max_tokens: int,
default_sampling_params: dict | None = None,
) -> SamplingParams:
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
response_format = self.response_format
if response_format is not None:
structured_outputs_kwargs = dict[str, Any]()
# Set structured output params for response format
if response_format.type == "json_object":
structured_outputs_kwargs["json_object"] = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if len(structured_outputs_kwargs) > 0:
self.structured_outputs = (
StructuredOutputsParams(**structured_outputs_kwargs)
if self.structured_outputs is None
else replace(self.structured_outputs, **structured_outputs_kwargs)
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):
if data.get("structured_outputs", None) is None:
return data
structured_outputs_kwargs = data["structured_outputs"]
# structured_outputs may arrive as a dict (from JSON/raw kwargs) or
# as a StructuredOutputsParams dataclass instance.
is_dataclass = isinstance(structured_outputs_kwargs, StructuredOutputsParams)
count = sum(
(
getattr(structured_outputs_kwargs, k, None)
if is_dataclass
else structured_outputs_kwargs.get(k)
)
is not None
for k in ("json", "regex", "choice")
)
if count > 1:
raise VLLMValidationError(
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
)
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise VLLMValidationError(
"`logprobs` must be a positive value.",
parameter="logprobs",
value=logprobs,
)
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
prompt = data.get("prompt")
prompt_embeds = data.get("prompt_embeds")
prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
embeds_is_empty = prompt_embeds is None or (
isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
)
if prompt_is_empty and embeds_is_empty:
raise ValueError(
"Either prompt or prompt_embeds must be provided and non-empty."
)
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
class CompletionLogProbs(OpenAIBaseModel):
text_offset: list[int] = Field(default_factory=list)
token_logprobs: list[float | None] = Field(default_factory=list)
tokens: list[str] = Field(default_factory=list)
top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
token_ids: list[int] | None = None # For response
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None # For prompt
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: Literal["text_completion"] = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseChoice]
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
system_fingerprint: str | None = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
# not part of the OpenAI spec but for tracing the tokens
# prompt tokens is put into choice to align with CompletionResponseChoice
prompt_token_ids: list[int] | None = None
token_ids: list[int] | None = None
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)

View File

@@ -0,0 +1,681 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.completion.protocol import (
CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
mc = self.model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse:
"""
render completion request by validating and preprocessing inputs.
Returns:
A list of engine_prompts on success,
or an ErrorResponse on failure.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response("suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response("Echo is unsupported with prompt embeds.")
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds."
)
try:
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
return engine_prompts
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
result = await self.render_completion_request(request)
if isinstance(result, ErrorResponse):
return result
engine_prompts = result
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# We do not stream the results when using beam search.
stream = request.stream and not request.use_beam_search
# Streaming response
tokenizer = self.renderer.tokenizer
if stream:
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
)
# Non-streaming response
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[ProcessorInputs],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: list[int] | None = None
assert request.max_tokens is not None
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i] and request.return_token_ids:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (
not delta_text
and not delta_token_ids
and not previous_num_tokens[i]
):
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(
as_list(output.token_ids)
if request.return_token_ids
else None
),
)
],
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage_info,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: list[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(
prompt_token_ids if request.return_token_ids else None
),
token_ids=(
as_list(output.token_ids) if request.return_token_ids else None
),
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if (
self.enable_prompt_tokens_details
and last_final_res
and last_final_res.num_cached_tokens
):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=last_final_res.num_cached_tokens
)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: TokenizerLike | None,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: list[int] = []
out_token_logprobs: list[float | None] = []
out_tokens: list[str] = []
out_top_logprobs: list[dict[str, float] | None] = []
last_token_len = 0
should_return_as_token_id = (
return_as_token_id
if return_as_token_id is not None
else self.return_tokens_as_token_ids
)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise VLLMValidationError(
"Unable to get tokenizer because "
"`skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_token,
token_id,
tokenizer,
return_as_token_id=should_return_as_token_id,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append(
{
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
): max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
}
)
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, ClassVar, Literal, TypeAlias
import regex as re
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class OpenAIBaseModel(BaseModel):
# OpenAI API does allow extra fields
model_config = ConfigDict(extra="allow")
# Cache class field names
field_names: ClassVar[set[str] | None] = None
@model_validator(mode="wrap")
@classmethod
def __log_extra_fields__(cls, data, handler):
result = handler(data)
if not isinstance(data, dict):
return result
field_names = cls.field_names
if field_names is None:
# Get all class field names and their potential aliases
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
# Compare against both field names and aliases
if any(k not in field_names for k in data):
logger.warning(
"The following fields were present in the request but ignored: %s",
data.keys() - field_names,
)
return result
class ErrorInfo(OpenAIBaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: str | None = None
is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: str | None = None
parent: str | None = None
max_model_len: int | None = None
permission: list[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
object: str = "list"
data: list[ModelCard] = Field(default_factory=list)
class PromptTokenUsageInfo(OpenAIBaseModel):
cached_tokens: int | None = None
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: int | None = 0
prompt_tokens_details: PromptTokenUsageInfo | None = None
class RequestResponseMetadata(BaseModel):
request_id: str
final_usage_info: UsageInfo | None = None
class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: str | None = None
# schema is the field in openai but that causes conflicts with pydantic so
# instead use json_schema with an alias
json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
strict: bool | None = None
class LegacyStructuralTag(OpenAIBaseModel):
begin: str
# schema is the field, but that causes conflicts with pydantic so
# instead use structural_tag_schema with an alias
structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
end: str
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
structures: list[LegacyStructuralTag]
triggers: list[str]
class StructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
format: Any
AnyStructuralTagResponseFormat: TypeAlias = (
LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)
class ResponseFormat(OpenAIBaseModel):
# type must be "json_schema", "json_object", or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: JsonSchemaResponseFormat | None = None
AnyResponseFormat: TypeAlias = (
ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
class StreamOptions(OpenAIBaseModel):
include_usage: bool | None = True
continuous_usage_stats: bool | None = False
class FunctionDefinition(OpenAIBaseModel):
name: str
description: str | None = None
parameters: dict[str, Any] | None = None
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
qualname: str
args: list[Any] | None = None
kwargs: dict[str, Any] | None = None
model_config = ConfigDict(extra="forbid")
LogitsProcessors = list[str | LogitsProcessorConstructor]
def get_logits_processors(
processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
if processors and pattern:
logits_processors = []
for processor in processors:
qualname = processor if isinstance(processor, str) else processor.qualname
if not re.match(pattern, qualname):
raise ValueError(
f"Logits processor '{qualname}' is not allowed by this "
"server. See --logits-processor-pattern engine argument "
"for more information."
)
try:
logits_processor = resolve_obj_by_qualname(qualname)
except Exception as e:
raise ValueError(
f"Logits processor '{qualname}' could not be resolved: {e}"
) from e
if isinstance(processor, LogitsProcessorConstructor):
logits_processor = logits_processor(
*processor.args or [], **processor.kwargs or {}
)
logits_processors.append(logits_processor)
return logits_processors
elif processors:
raise ValueError(
"The `logits_processors` argument is not supported by this "
"server. See --logits-processor-pattern engine argument "
"for more information."
)
return None
class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str
arguments: str
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=make_tool_call_id)
type: Literal["function"] = "function"
function: FunctionCall
class DeltaFunctionCall(BaseModel):
name: str | None = None
arguments: str | None = None
# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
id: str | None = None
type: Literal["function"] | None = None
index: int
function: DeltaFunctionCall | None = None
class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
# extracted tool calls
tool_calls: list[ToolCall]
# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
content: str | None = None
class DeltaMessage(OpenAIBaseModel):
role: str | None = None
content: str | None = None
reasoning: str | None = None
tool_calls: list[DeltaToolCall] = Field(default_factory=list)
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
model: str | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import FastAPI
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
def register_generate_api_routers(app: FastAPI):
from vllm.entrypoints.openai.chat_completion.api_router import (
attach_router as register_chat_api_router,
)
register_chat_api_router(app)
from vllm.entrypoints.openai.responses.api_router import (
attach_router as register_responses_api_router,
)
register_responses_api_router(app)
from vllm.entrypoints.openai.completion.api_router import (
attach_router as register_completion_api_router,
)
register_completion_api_router(app)
from vllm.entrypoints.anthropic.api_router import (
attach_router as register_anthropic_api_router,
)
register_anthropic_api_router(app)
async def init_generate_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.mcp.tool_server import (
DemoToolServer,
MCPToolServer,
ToolServer,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.serve.disagg.serving import ServingTokens
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()
assert isinstance(tool_server, DemoToolServer)
await tool_server.init_and_validate()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
resolved_chat_template = load_chat_template(args.chat_template)
state.openai_serving_responses = (
OpenAIServingResponses(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
tool_server=tool_server,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_chat = (
OpenAIServingChat(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = (
OpenAIServingCompletion(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.anthropic_serving_messages = (
AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
)
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
)
if "generate" in supported_tasks
else None
)

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = models(raw_request)
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
def attach_router(app: FastAPI):
app.include_router(router)

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: str | None = None

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from asyncio import Lock
from collections import defaultdict
from http import HTTPStatus
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath, LoRAModulePath
from vllm.entrypoints.serve.lora.protocol import (
LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.utils import sanitize_message
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.utils.counter import AtomicCounter
logger = init_logger(__name__)
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
engine_client: EngineClient,
base_model_paths: list[BaseModelPath],
*,
lora_modules: list[LoRAModulePath] | None = None,
):
super().__init__()
self.engine_client = engine_client
self.base_model_paths = base_model_paths
self.static_lora_modules = lora_modules
self.lora_requests: dict[str, LoRARequest] = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers():
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name)
)
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.model_config = self.engine_client.model_config
self.renderer = self.engine_client.renderer
self.io_processor = self.engine_client.io_processor
self.input_processor = self.engine_client.input_processor
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoRAAdapterRequest(
lora_path=lora.path, lora_name=lora.name
)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name
)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all adapters."""
max_model_len = self.model_config.max_model_len
model_cards = [
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(
id=lora.lora_name,
root=lora.path,
parent=lora.base_model_name
if lora.base_model_name
else self.base_model_paths[0].name,
permission=[ModelPermission()],
)
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_path = request.lora_path
lora_int_id = (
self.lora_requests[lora_name].lora_int_id
if lora_name in self.lora_requests
else self.lora_id_counter.inc(1)
)
lora_request = LoRARequest(
lora_name=lora_name,
lora_int_id=lora_int_id,
lora_path=lora_path,
load_inplace=request.load_inplace,
)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also preload it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except Exception as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
return create_error_response(
message=str(e), err_type=error_type, status_code=status_code
)
self.lora_requests[lora_name] = lora_request
logger.info(
"Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_unload_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
# Safe to delete now since we hold the lock
del self.lora_requests[lora_name]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# If not loading inplace
# Check if the lora adapter with the given name already exists
if not request.load_inplace and request.lora_name in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' has already been "
"loaded. If you want to load the adapter in place, set 'load_inplace'"
" to True.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
async def _check_unload_lora_adapter_request(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if 'lora_name' is not provided return an error
if not request.lora_name:
return create_error_response(
message="'lora_name' needs to be provided to unload a LoRA adapter.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# Check if the lora adapter with the given name exists
if request.lora_name not in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
return None
async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
"""Attempt to resolve a LoRA adapter using available resolvers.
Args:
lora_name: Name/identifier of the LoRA adapter
Returns:
LoRARequest if found and loaded successfully.
ErrorResponse (404) if no resolver finds the adapter.
ErrorResponse (400) if adapter(s) are found but none load.
"""
async with self.lora_resolver_lock[lora_name]:
# First check if this LoRA is already loaded
if lora_name in self.lora_requests:
return self.lora_requests[lora_name]
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
found_adapter = False
# Try to resolve using available resolvers
for resolver in self.lora_resolvers:
lora_request = await resolver.resolve_lora(base_model_name, lora_name)
if lora_request is not None:
found_adapter = True
lora_request.lora_int_id = unique_id
try:
await self.engine_client.add_lora(lora_request)
self.lora_requests[lora_name] = lora_request
logger.info(
"Resolved and loaded LoRA adapter '%s' using %s",
lora_name,
resolver.__class__.__name__,
)
return lora_request
except BaseException as e:
logger.warning(
"Failed to load LoRA '%s' resolved by %s: %s. "
"Trying next resolver.",
lora_name,
resolver.__class__.__name__,
e,
)
continue
if found_adapter:
# An adapter was found, but all attempts to load it failed.
return create_error_response(
message=(
f"LoRA adapter '{lora_name}' was found but could not be loaded."
),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST,
)
else:
# No adapter was found
return create_error_response(
message=f"LoRA adapter {lora_name} does not exist",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
return ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
)
)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utility functions that create ORCA endpoint load report response headers.
"""
import json
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot
logger = init_logger(__name__)
def create_orca_header(
metrics_format: str, named_metrics: list[tuple[str, float]]
) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format
and adds custom metrics to named_metrics.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
- named_metrics (List[Tuple[str, float]]): List of tuples with metric names
and their corresponding double values.
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if metrics_format.lower() not in ["text", "json"]:
logger.warning(
"Warning: `%s` format is not supported in the ORCA response header",
format,
)
return None
header = {}
orca_report = {
"named_metrics": {
metric_name: value
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
}
}
# output example:
# endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
if metrics_format.lower() == "text":
native_http_header = ", ".join(
[
f"named_metrics.{metric_name}={value}"
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
]
)
header["endpoint-load-metrics"] = f"TEXT {native_http_header}"
# output example:
# endpoint-load-metrics: JSON “named_metrics”: {“custom-metric-util”: 0.4}
elif metrics_format.lower() == "json":
header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"
logger.info("Created ORCA header %s", header)
return header
def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
"""
Collects current metrics from Prometheus and returns some of them
in the form of the `named_metrics` list for `create_orca_header()`.
Parameters:
- None
Returns:
- list[tuple[str, float]]: List of tuples of metric names and their values.
"""
named_metrics: list[tuple[str, float]] = []
# Map from prometheus metric names to ORCA named metrics.
prometheus_to_orca_metrics = {
"vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
"vllm:num_requests_waiting": "num_requests_waiting",
}
metrics = get_metrics_snapshot()
for metric in metrics:
orca_name = prometheus_to_orca_metrics.get(metric.name)
# If this metric is mapped into ORCA, then add it to the report.
# Note: Only Gauge metrics are currently supported.
if orca_name is not None and isinstance(metric, Gauge):
named_metrics.append((str(orca_name), float(metric.value)))
return named_metrics
def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format.
Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if not metrics_format:
return None
# Get named metrics from prometheus.
named_metrics = get_named_metrics_from_prometheus()
return create_orca_header(metrics_format, named_metrics)

View File

@@ -0,0 +1,394 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
from collections.abc import Iterable, Sequence
from typing import Literal
from openai.types.responses.tool import Tool
from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
from vllm import envs
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionToolsParam
from vllm.logger import init_logger
logger = init_logger(__name__)
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOL_TO_MCP_SERVER_LABEL: dict[str, str] = {
"python": "code_interpreter",
"browser": "web_search_preview",
"container": "container",
}
# Derive MCP_BUILTIN_TOOLS from the canonical mapping
MCP_BUILTIN_TOOLS: set[str] = set(BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())
def has_custom_tools(tool_types: set[str]) -> bool:
"""
Checks if the given tool types are custom tools
(i.e. any tool other than MCP buildin tools)
"""
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: str | None = None,
reasoning_effort: Literal["high", "medium", "low"] | None = None,
start_date: str | None = None,
browser_description: str | None = None,
python_description: str | None = None,
container_description: str | None = None,
instructions: str | None = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
current_identity = sys_msg_content.model_identity
new_identity = (
f"{current_identity}\n{instructions}" if current_identity else instructions
)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]
)
if start_date is None:
# NOTE(woosuk): This brings non-determinism in vLLM.
# Set VLLM_SYSTEM_START_DATE to pin it.
start_date = envs.VLLM_SYSTEM_START_DATE or datetime.datetime.now().strftime(
"%Y-%m-%d"
)
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def create_tool_definition(tool: ChatCompletionToolsParam | Tool):
if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: str | None = None,
tools: list[Tool | ChatCompletionToolsParam] | None = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools: list[Tool | ChatCompletionToolsParam] = []
for tool in tools:
if tool.type in (
"web_search_preview",
"code_interpreter",
"container",
):
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
create_tool_definition(tool) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions
)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs: list[Message] = []
tool_id_names: dict[str, str] = {}
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
msgs = auto_drop_analysis_messages(msgs)
return msgs
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle the case where we're in longer multi-turn
conversations and the client gave us reasoning content from previous turns of
the conversation with multiple assistant messages to the final channel in the
conversation.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index = -1
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if msg.author.role == "assistant" and msg.channel == "final":
last_assistant_final_index = i
break
cleaned_msgs: list[Message] = []
for i, msg in enumerate(msgs):
if i < last_assistant_final_index and msg.channel == "analysis":
continue
cleaned_msgs.append(msg)
return cleaned_msgs
def flatten_chat_text_content(content: str | list | None) -> str | None:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
return content
def parse_chat_input_to_harmony_message(
chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names = tool_id_names or {}
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
msgs: list[Message] = []
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls", [])
if role == "assistant" and tool_calls:
content = flatten_chat_text_content(chat_msg.get("content"))
if content:
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
commentary_msg = commentary_msg.with_channel("commentary")
msgs.append(commentary_msg)
reasoning = chat_msg.get("reasoning")
if reasoning:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
# Officially, this should be `<|constrain|>json` but there is not clear
# evidence that improves accuracy over `json` and some anecdotes to the
# contrary. Further testing of the different content_types is needed.
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = (
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Non-tool reasoning content
reasoning = chat_msg.get("reasoning")
if role == "assistant" and reasoning:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
# Default: user/assistant/system messages with content
content = chat_msg.get("content") or ""
if content is None:
content = ""
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if role == "assistant" and contents and contents[0].text:
msg = Message.from_role_and_contents(role, contents)
# Send non-tool assistant messages to the final channel
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
elif role != "assistant":
msg = Message.from_role_and_contents(role, contents)
msgs.append(msg)
return msgs
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT
)
return token_ids
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use,this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
# Get completed messages from the parser
# - analysis channel: hidden reasoning
# - commentary channel without recipient (preambles): visible to user
# - final channel: visible to user
# - commentary with recipient (tool calls): handled separately by tool parser
reasoning_texts = [
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
]
final_texts = [
msg.content[0].text
for msg in output_msgs
if msg.channel == "final" or (msg.channel == "commentary" and not msg.recipient)
]
# Extract partial messages from the parser
if parser.current_channel == "analysis" and parser.current_content:
reasoning_texts.append(parser.current_content)
elif parser.current_channel == "final" and parser.current_content:
final_texts.append(parser.current_content)
elif (
parser.current_channel == "commentary"
and not parser.current_recipient
and parser.current_content
):
# Preambles (commentary without recipient) are visible to user
final_texts.append(parser.current_content)
# Flatten multiple messages into a single string
reasoning: str | None = "\n".join(reasoning_texts)
final_content: str | None = "\n".join(final_texts)
# Return None instead of empty string since existing callers check for None
reasoning = reasoning or None
final_content = final_content or None
return reasoning, final_content, is_tool_call

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
Content,
ResponseReasoningItem,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
logger = logging.getLogger(__name__)
class ResponsesParser:
"""Incremental parser over completion tokens with reasoning support."""
def __init__(
self,
*,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
response_messages
)
self.num_init_messages = len(response_messages)
self.tokenizer = tokenizer
self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
# Store the last finish_reason to determine response status
self.finish_reason: str | None = None
def process(self, output: CompletionOutput) -> "ResponsesParser":
# Store the finish_reason from the output
self.finish_reason = output.finish_reason
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
output.text, request=self.request
)
if reasoning_content:
self.response_messages.append(
ResponseReasoningItem(
type="reasoning",
id=f"rs_{random_uuid()}",
summary=[],
content=[
Content(
type="reasoning_text",
text=reasoning_content,
)
],
)
)
function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
if content:
self.response_messages.append(
ResponseOutputMessage(
type="message",
id=f"msg_{random_uuid()}",
status="completed",
role="assistant",
content=[
ResponseOutputText(
annotations=[], # TODO
type="output_text",
text=content,
logprobs=None, # TODO
)
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)
return self
def make_response_output_items_from_parsable_context(
self,
) -> list[ResponseOutputItem]:
"""Given a list of sentences, construct ResponseOutput Items."""
response_messages = self.response_messages[self.num_init_messages :]
output_messages: list[ResponseOutputItem] = []
for message in response_messages:
if not isinstance(message, ResponseFunctionToolCallOutputItem):
output_messages.append(message)
else:
if len(output_messages) == 0:
raise ValueError(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"{MCP_PREFIX}{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
status="completed",
output=message.output,
# TODO: support error output
)
output_messages[-1] = mcp_message
return output_messages
def get_responses_parser_for_simple_context(
*,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
Returns:
ResponsesParser instance configured with the provided parser
"""
return ResponsesParser(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import APIRouter, FastAPI, WebSocket
from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
router = APIRouter()
@router.websocket("/v1/realtime")
async def realtime_endpoint(websocket: WebSocket):
"""WebSocket endpoint for realtime audio transcription.
Protocol:
1. Client connects to ws://host/v1/realtime
2. Server sends session.created event
3. Client optionally sends session.update with model/params
4. Client sends input_audio_buffer.commit when ready
5. Client sends input_audio_buffer.append events with base64 PCM16 chunks
6. Server processes and sends transcription.delta events
7. Server sends transcription.done with final text + usage
8. Repeat from step 5 for next utterance
9. Optionally, client sends input_audio_buffer.commit with final=True
to signal audio input is finished. Useful when streaming audio files
Audio format: PCM16, 16kHz, mono, base64-encoded
"""
app = websocket.app
serving = app.state.openai_serving_realtime
connection = RealtimeConnection(websocket, serving)
await connection.handle_connection()
def attach_router(app: FastAPI):
"""Attach the realtime router to the FastAPI app."""
app.include_router(router)
logger.info("Realtime API router attached")
def init_realtime_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_realtime = (
OpenAIServingRealtime(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
)
if "realtime" in supported_tasks
else None
)

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import json
from collections.abc import AsyncGenerator
from http import HTTPStatus
from uuid import uuid4
import numpy as np
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect
from vllm import envs
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.realtime.protocol import (
ErrorEvent,
InputAudioBufferAppend,
InputAudioBufferCommit,
SessionCreated,
TranscriptionDelta,
TranscriptionDone,
)
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
logger = init_logger(__name__)
class RealtimeConnection:
"""Manages WebSocket lifecycle and state for realtime transcription.
This class handles:
- WebSocket connection lifecycle (accept, receive, send, close)
- Event routing (session.update, append, commit)
- Audio buffering via asyncio.Queue
- Generation task management
- Error handling and cleanup
"""
def __init__(self, websocket: WebSocket, serving: OpenAIServingRealtime):
self.websocket = websocket
self.connection_id = f"ws-{uuid4()}"
self.serving = serving
self.audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
self.generation_task: asyncio.Task | None = None
self._is_connected = False
self._is_model_validated = False
self._max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
async def handle_connection(self):
"""Main connection loop."""
await self.websocket.accept()
logger.debug("WebSocket connection accepted: %s", self.connection_id)
self._is_connected = True
# Send session created event
await self.send(SessionCreated())
try:
while True:
message = await self.websocket.receive_text()
try:
event = json.loads(message)
await self.handle_event(event)
except json.JSONDecodeError:
await self.send_error("Invalid JSON", "invalid_json")
except Exception as e:
logger.exception("Error handling event: %s", e)
await self.send_error(str(e), "processing_error")
except WebSocketDisconnect:
logger.debug("WebSocket disconnected: %s", self.connection_id)
self._is_connected = False
except Exception as e:
logger.exception("Unexpected error in connection: %s", e)
finally:
await self.cleanup()
def _check_model(self, model: str | None) -> None | ErrorResponse:
if self.serving._is_model_supported(model):
return None
return self.serving.create_error_response(
message=f"The model `{model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
async def handle_event(self, event: dict):
"""Route events to handlers.
Supported event types:
- session.update: Configure model
- input_audio_buffer.append: Add audio chunk to queue
- input_audio_buffer.commit: Start transcription generation
"""
event_type = event.get("type")
if event_type == "session.update":
logger.debug("Session updated: %s", event)
self._check_model(event["model"])
self._is_model_validated = True
elif event_type == "input_audio_buffer.append":
append_event = InputAudioBufferAppend(**event)
try:
audio_bytes = base64.b64decode(append_event.audio)
# Convert PCM16 bytes to float32 numpy array
audio_array = (
np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
/ 32768.0
)
if len(audio_array) / 1024**2 > self._max_audio_filesize_mb:
raise VLLMValidationError(
"Maximum file size exceeded",
parameter="audio_filesize_mb",
value=len(audio_array) / 1024**2,
)
if len(audio_array) == 0:
raise VLLMValidationError("Can't process empty audio.")
# Put audio chunk in queue
self.audio_queue.put_nowait(audio_array)
except Exception as e:
logger.error("Failed to decode audio: %s", e)
await self.send_error("Invalid audio data", "invalid_audio")
elif event_type == "input_audio_buffer.commit":
if not self._is_model_validated:
err_msg = (
"Model not validated. Make sure to validate the"
" model by sending a session.update event."
)
await self.send_error(
err_msg,
"model_not_validated",
)
commit_event = InputAudioBufferCommit(**event)
# final signals that the audio is finished
if commit_event.final:
self.audio_queue.put_nowait(None)
else:
await self.start_generation()
else:
await self.send_error(f"Unknown event type: {event_type}", "unknown_event")
async def audio_stream_generator(self) -> AsyncGenerator[np.ndarray, None]:
"""Generator that yields audio chunks from the queue."""
while True:
audio_chunk = await self.audio_queue.get()
if audio_chunk is None: # Sentinel value to stop
break
yield audio_chunk
async def start_generation(self):
"""Start the transcription generation task."""
if self.generation_task is not None and not self.generation_task.done():
logger.warning("Generation already in progress, ignoring commit")
return
# Create audio stream generator
audio_stream = self.audio_stream_generator()
input_stream = asyncio.Queue[list[int]]()
# Transform to StreamingInput generator
streaming_input_gen = self.serving.transcribe_realtime(
audio_stream, input_stream
)
# Start generation task
self.generation_task = asyncio.create_task(
self._run_generation(streaming_input_gen, input_stream)
)
async def _run_generation(
self,
streaming_input_gen: AsyncGenerator,
input_stream: asyncio.Queue[list[int]],
):
"""Run the generation and stream results back to the client.
This method:
1. Creates sampling parameters from session config
2. Passes the streaming input generator to engine.generate()
3. Streams transcription.delta events as text is generated
4. Sends final transcription.done event with usage stats
5. Feeds generated token IDs back to input_stream for next iteration
6. Cleans up the audio queue
"""
request_id = f"rt-{self.connection_id}-{uuid4()}"
full_text = ""
prompt_token_ids_len: int = 0
completion_tokens_len: int = 0
try:
# Create sampling params
from vllm.sampling_params import RequestOutputKind, SamplingParams
sampling_params = SamplingParams.from_optional(
temperature=0.0,
max_tokens=self.serving.model_cls.realtime_max_tokens,
output_kind=RequestOutputKind.DELTA,
skip_clone=True,
)
# Pass the streaming input generator to the engine
# The engine will consume audio chunks as they arrive and
# stream back transcription results incrementally
result_gen = self.serving.engine_client.generate(
prompt=streaming_input_gen,
sampling_params=sampling_params,
request_id=request_id,
)
# Stream results back to client as they're generated
async for output in result_gen:
if output.outputs and len(output.outputs) > 0:
if not prompt_token_ids_len and output.prompt_token_ids:
prompt_token_ids_len = len(output.prompt_token_ids)
delta = output.outputs[0].text
full_text += delta
# append output to input
input_stream.put_nowait(list(output.outputs[0].token_ids))
await self.send(TranscriptionDelta(delta=delta))
completion_tokens_len += len(output.outputs[0].token_ids)
if not self._is_connected:
# finish because websocket connection was killed
break
usage = UsageInfo(
prompt_tokens=prompt_token_ids_len,
completion_tokens=completion_tokens_len,
total_tokens=prompt_token_ids_len + completion_tokens_len,
)
# Send final completion event
await self.send(TranscriptionDone(text=full_text, usage=usage))
# Clear queue for next utterance
while not self.audio_queue.empty():
self.audio_queue.get_nowait()
except Exception as e:
logger.exception("Error in generation: %s", e)
await self.send_error(str(e), "processing_error")
async def send(
self, event: SessionCreated | TranscriptionDelta | TranscriptionDone
):
"""Send event to client."""
data = event.model_dump_json()
await self.websocket.send_text(data)
async def send_error(self, message: str, code: str | None = None):
"""Send error event to client."""
error_event = ErrorEvent(error=message, code=code)
await self.websocket.send_text(error_event.model_dump_json())
async def cleanup(self):
"""Cleanup resources."""
# Signal audio stream to stop
self.audio_queue.put_nowait(None)
# Cancel generation task if running
if self.generation_task and not self.generation_task.done():
self.generation_task.cancel()
logger.debug("Connection cleanup complete: %s", self.connection_id)

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Literal
from pydantic import Field
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
UsageInfo,
)
from vllm.utils import random_uuid
# Client -> Server Events
class InputAudioBufferAppend(OpenAIBaseModel):
"""Append audio chunk to buffer"""
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
audio: str # base64-encoded PCM16 @ 16kHz
class InputAudioBufferCommit(OpenAIBaseModel):
"""Process accumulated audio buffer"""
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
final: bool = False
# Server -> Client Events
class SessionUpdate(OpenAIBaseModel):
"""Configure session parameters"""
type: Literal["session.update"] = "session.update"
model: str | None = None
class SessionCreated(OpenAIBaseModel):
"""Connection established notification"""
type: Literal["session.created"] = "session.created"
id: str = Field(default_factory=lambda: f"sess-{random_uuid()}")
created: int = Field(default_factory=lambda: int(time.time()))
class TranscriptionDelta(OpenAIBaseModel):
"""Incremental transcription text"""
type: Literal["transcription.delta"] = "transcription.delta"
delta: str # Incremental text
class TranscriptionDone(OpenAIBaseModel):
"""Final transcription with usage stats"""
type: Literal["transcription.done"] = "transcription.done"
text: str # Complete transcription
usage: UsageInfo | None = None
class ErrorEvent(OpenAIBaseModel):
"""Error notification"""
type: Literal["error"] = "error"
error: str
code: str | None = None

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Literal, cast
import numpy as np
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt
logger = init_logger(__name__)
class OpenAIServingRealtime(OpenAIServing):
"""Realtime audio transcription service via WebSocket streaming.
Provides streaming audio-to-text transcription by transforming audio chunks
into StreamingInput objects that can be consumed by the engine.
"""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.task_type: Literal["realtime"] = "realtime"
logger.info("OpenAIServingRealtime initialized for task: %s", self.task_type)
@cached_property
def model_cls(self) -> type[SupportsRealtime]:
"""Get the model class that supports transcription."""
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsRealtime], model_cls)
async def transcribe_realtime(
self,
audio_stream: AsyncGenerator[np.ndarray, None],
input_stream: asyncio.Queue[list[int]],
) -> AsyncGenerator[StreamingInput, None]:
"""Transform audio stream into StreamingInput for engine.generate().
Args:
audio_stream: Async generator yielding float32 numpy audio arrays
input_stream: Queue containing context token IDs from previous
generation outputs. Used for autoregressive multi-turn
processing where each generation's output becomes the context
for the next iteration.
Yields:
StreamingInput objects containing audio prompts for the engine
"""
model_config = self.model_config
renderer = self.renderer
# mypy is being stupid
# TODO(Patrick) - fix this
stream_input_iter = cast(
AsyncGenerator[PromptType, None],
self.model_cls.buffer_realtime_audio(
audio_stream, input_stream, model_config
),
)
async for prompt in stream_input_iter:
parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt)

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
ResponsesResponse,
StreamingResponsesResponse,
)
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def responses(request: Request) -> OpenAIServingResponses | None:
return request.app.state.openai_serving_responses
async def _convert_stream_to_sse_events(
generator: AsyncGenerator[StreamingResponsesResponse, None],
) -> AsyncGenerator[str, None]:
"""Convert the generator to a stream of events in SSE format"""
async for event in generator:
event_type = getattr(event, "type", "unknown")
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
event_data = (
f"event: {event_type}\ndata: {event.model_dump_json(indent=None)}\n\n"
)
yield event_data
@router.post(
"/v1/responses",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_responses(request: ResponsesRequest, raw_request: Request):
handler = responses(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Responses API"
)
try:
generator = await handler.create_responses(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ResponsesResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(
content=_convert_stream_to_sse_events(generator), media_type="text/event-stream"
)
@router.get("/v1/responses/{response_id}")
@load_aware_call
async def retrieve_responses(
response_id: str,
raw_request: Request,
starting_after: int | None = None,
stream: bool | None = False,
):
handler = responses(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Responses API"
)
try:
response = await handler.retrieve_responses(
response_id,
starting_after=starting_after,
stream=stream,
)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
content=response.model_dump(), status_code=response.error.code
)
elif isinstance(response, ResponsesResponse):
return JSONResponse(content=response.model_dump())
return StreamingResponse(
content=_convert_stream_to_sse_events(response), media_type="text/event-stream"
)
@router.post("/v1/responses/{response_id}/cancel")
@load_aware_call
async def cancel_responses(response_id: str, raw_request: Request):
handler = responses(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Responses API"
)
try:
response = await handler.cancel_responses(response_id)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
content=response.model_dump(), status_code=response.error.code
)
return JSONResponse(content=response.model_dump())
def attach_router(app: FastAPI):
app.include_router(router)

View File

@@ -0,0 +1,918 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import copy
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Final, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.mcp.tool import Tool
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from vllm.entrypoints.openai.parser.responses_parser import (
get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
if TYPE_CHECKING:
from mcp.client import ClientSession
logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}"
)
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
class TurnMetrics:
"""Tracks token and toolcall details for a single conversation turn."""
def __init__(
self,
input_tokens: int = 0,
output_tokens: int = 0,
cached_input_tokens: int = 0,
tool_output_tokens: int = 0,
) -> None:
self.input_tokens = input_tokens
self.output_tokens = output_tokens
self.cached_input_tokens = cached_input_tokens
self.tool_output_tokens = tool_output_tokens
def reset(self) -> None:
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
self.cached_input_tokens = 0
self.tool_output_tokens = 0
def copy(self) -> "TurnMetrics":
"""Create a copy of this turn's token counts."""
return TurnMetrics(
self.input_tokens,
self.output_tokens,
self.cached_input_tokens,
self.tool_output_tokens,
)
class ConversationContext(ABC):
@abstractmethod
def append_output(self, output: RequestOutput) -> None:
pass
@abstractmethod
def append_tool_output(self, output) -> None:
pass
@abstractmethod
async def call_tool(self) -> list[Message]:
pass
@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass
@abstractmethod
def render_for_completion(self) -> list[int]:
pass
@abstractmethod
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
) -> None:
pass
@abstractmethod
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
def _create_json_parse_error_messages(
last_msg: Message, e: json.JSONDecodeError
) -> list[Message]:
"""
Creates an error message when json parse failed.
"""
error_msg = (
f"Error parsing tool arguments as JSON: {str(e)}. "
"Please ensure the tool call arguments are valid JSON and try again."
)
content = TextContent(text=error_msg)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
class SimpleContext(ConversationContext):
"""This is a context that cannot handle MCP tool calls"""
def __init__(self):
self.last_output = None
# Accumulated final output for streaming mode
self._accumulated_text: str = ""
self._accumulated_token_ids: list[int] = []
self._accumulated_logprobs: list = []
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# todo num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
# not implemented yet for SimpleContext
self.all_turn_metrics = []
self.input_messages: list[ResponseRawMessageAndToken] = []
def append_output(self, output) -> None:
self.last_output = output
if not isinstance(output, RequestOutput):
raise ValueError("SimpleContext only supports RequestOutput.")
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
# Accumulate text, token_ids, and logprobs for streaming mode
delta_output = output.outputs[0]
self._accumulated_text += delta_output.text
self._accumulated_token_ids.extend(delta_output.token_ids)
if delta_output.logprobs is not None:
self._accumulated_logprobs.extend(delta_output.logprobs)
if len(self.input_messages) == 0:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
self.input_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
@property
def output_messages(self) -> list[ResponseRawMessageAndToken]:
"""Return consolidated output as a single message.
In streaming mode, text and tokens are accumulated across many deltas.
This property returns them as a single entry rather than one per delta.
"""
if not self._accumulated_text and not self._accumulated_token_ids:
return []
return [
ResponseRawMessageAndToken(
message=self._accumulated_text,
tokens=list(self._accumulated_token_ids),
)
]
@property
def final_output(self) -> RequestOutput | None:
"""Return the final output, with complete text/token_ids/logprobs."""
if self.last_output is not None and self.last_output.outputs:
assert isinstance(self.last_output, RequestOutput)
final_output = copy.copy(self.last_output)
# copy inner item to avoid modify last_output
final_output.outputs = [replace(item) for item in self.last_output.outputs]
final_output.outputs[0].text = self._accumulated_text
final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
if self._accumulated_logprobs:
final_output.outputs[0].logprobs = self._accumulated_logprobs
return final_output
return self.last_output
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
def need_builtin_tool_call(self) -> bool:
return False
async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
) -> None:
pass
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class ParsableContext(ConversationContext):
def __init__(
self,
*,
response_messages: list[ResponseInputOutputItem],
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
# not implemented yet for ParsableContext
self.all_turn_metrics: list[TurnMetrics] = []
if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.")
self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
self._accumulated_token_ids: list[int] = []
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
self.parser.process(output.outputs[0])
output_token_ids = output.outputs[0].token_ids or []
self._accumulated_token_ids.extend(output_token_ids)
# only store if enable_response_messages is True, save memory
if self.request.enable_response_messages:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
if len(self.input_messages) == 0:
self.input_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
else:
self.output_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
self.output_messages.append(
ResponseRawMessageAndToken(
message=output.outputs[0].text,
tokens=output.outputs[0].token_ids,
)
)
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
self.parser.response_messages.extend(output)
def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a builtin tool call
that the request has enabled."""
last_message = self.parser.response_messages[-1]
if last_message.type != "function_call":
return False
if last_message.name in ("code_interpreter", "python"):
return "python" in self.available_tools
if last_message.name == "web_search_preview":
return "browser" in self.available_tools
if last_message.name.startswith("container"):
return "container" in self.available_tools
return False
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"mcpo_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)
return [message]
async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.arguments)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.arguments)
result = await tool_session.call_tool("search", args)
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)
return [message]
async def call_container_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
# tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.arguments)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.arguments)
result = await tool_session.call_tool("exec", args)
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)
return [message]
async def call_tool(self) -> list[ResponseInputOutputItem]:
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
# change this to a mcp_ function call
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
self.parser.response_messages[-1] = last_msg
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
elif last_msg.name == "web_search_preview":
return await self.call_search_tool(self._tool_sessions["browser"], last_msg)
elif last_msg.name.startswith("container"):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg
)
return []
def render_for_completion(self):
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
):
if tool_server:
for tool_name in self.available_tools:
if tool_name in self._tool_sessions:
continue
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
available_tools: list[str],
):
self._messages = messages
self.finish_reason: str | None = None
self.available_tools = available_tools
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
# Turn tracking - replaces multiple individual tracking variables
self.current_turn_metrics = TurnMetrics()
# Track metrics for all turns
self.all_turn_metrics: list[TurnMetrics] = []
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
def _update_num_reasoning_tokens(self):
channel = self.parser.current_channel
if channel == "analysis":
self.num_reasoning_tokens += 1
elif channel == "commentary" and self.parser.current_recipient is not None:
# Tool interactions (python/browser/container) are hidden.
# Preambles (recipient=None) are visible user text.
self.num_reasoning_tokens += 1
def append_output(self, output: RequestOutput) -> None:
output_token_ids = output.outputs[0].token_ids
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
self._update_decode_token_usage(output)
# Append current turn to all turn list for next turn's calculations
self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset()
# append_output is called only once before tool calling
# in non-streaming case
# so we can append all the parser messages to _messages
output_msgs = self.parser.messages
# The responses finish reason is set in the last message
self.finish_reason = output.outputs[0].finish_reason
self._messages.extend(output_msgs)
def append_tool_output(self, output: list[Message]) -> None:
output_msgs = output
self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if output.prompt_token_ids is not None:
this_turn_input_tokens = len(output.prompt_token_ids)
else:
this_turn_input_tokens = 0
logger.error("RequestOutput appended contains no prompt_token_ids.")
# Update current turn input tokens
self.current_turn_metrics.input_tokens = this_turn_input_tokens
self.num_prompt_tokens += this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if self.is_first_turn:
self.is_first_turn = False
else:
previous_turn = self.all_turn_metrics[-1]
# start counting tool after first turn
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens = (
self.current_turn_metrics.input_tokens
- previous_turn.input_tokens
- previous_turn.output_tokens
)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if this_turn_tool_tokens < 0:
logger.error(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0.",
this_turn_tool_tokens,
self.current_turn_metrics.input_tokens,
previous_turn.input_tokens,
previous_turn.output_tokens,
)
this_turn_tool_tokens = 0
self.num_tool_output_tokens += this_turn_tool_tokens
self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
# Update cached tokens
num_cached_token = output.num_cached_tokens
if num_cached_token is not None:
self.num_cached_tokens += num_cached_token
self.current_turn_metrics.cached_input_tokens = num_cached_token
def _update_decode_token_usage(self, output: RequestOutput) -> int:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count = 0
if output.outputs:
for completion_output in output.outputs:
# only keep last round
updated_output_token_count += len(completion_output.token_ids)
self.num_output_tokens += updated_output_token_count
self.current_turn_metrics.output_tokens += updated_output_token_count
return updated_output_token_count
@property
def messages(self) -> list:
return self._messages
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is None:
return False
if recipient.startswith("browser."):
return "browser" in self.available_tools
if recipient.startswith("python"):
return "python" in self.available_tools
if recipient.startswith("container."):
return "container" in self.available_tools
return False
async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self._tool_sessions["browser"], last_msg
)
elif recipient.startswith("python"):
return await self.call_python_tool(
self._tool_sessions["python"], last_msg
)
elif recipient.startswith("container."):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg
)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")
return [
Message(
author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT,
)
]
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def call_container_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)
class StreamingHarmonyContext(HarmonyContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
self.first_tok_of_message = True
self.last_content_delta = None
@property
def messages(self) -> list:
return self._messages
def append_output(self, output: RequestOutput) -> None:
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
self.last_content_delta = None
if self.first_tok_of_message:
self._update_prefill_token_usage(output)
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
# beginning of a new message
self.first_tok_of_message = output.finished
last_delta_text = ""
for tok in output.outputs[0].token_ids:
self.parser.process(tok)
last_delta_text += self.parser.last_content_delta or ""
if last_delta_text:
self.last_content_delta = last_delta_text
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self.last_tok = tok
if len(self._messages) - self.num_init_messages < len(self.parser.messages):
self._messages.extend(
self.parser.messages[len(self._messages) - self.num_init_messages :]
)
def append_tool_output(self, output: list[Message]) -> None:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
# TODO: add tool_output messages to self._messages
def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START
def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
def render_for_completion(self) -> list[int]:
# now this list of tokens as next turn's starting tokens
# `<|start|>assistant`,
# we need to process them in parser.
rendered_tokens = super().render_for_completion()
last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)
return rendered_tokens

View File

@@ -0,0 +1,552 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Harmony ↔ Responses API conversion utilities.
Handles two directions:
1. Response Input → Harmony Messages (input parsing)
2. Harmony Messages → Response Output Items (output parsing)
"""
import json
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseOutputItem,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_web_search import (
ActionFind,
ActionOpenPage,
ActionSearch,
ResponseFunctionWebSearch,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai_harmony import Author, Message, Role, StreamableParser, TextContent
from vllm.entrypoints.openai.parser.harmony_utils import (
BUILTIN_TOOL_TO_MCP_SERVER_LABEL,
flatten_chat_text_content,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# 1. Private helpers for input parsing
# ---------------------------------------------------------------------------
def _parse_harmony_format_message(chat_msg: dict) -> Message:
"""Reconstruct a Message from Harmony-format dict,
preserving channel, recipient, and content_type."""
author_dict = chat_msg["author"]
role = author_dict.get("role")
name = author_dict.get("name")
raw_content = chat_msg.get("content", "")
if isinstance(raw_content, list):
# TODO: Support refusal and non-text content types.
contents = [TextContent(text=c.get("text", "")) for c in raw_content]
elif isinstance(raw_content, str):
contents = [TextContent(text=raw_content)]
else:
contents = [TextContent(text="")]
if name:
msg = Message.from_author_and_contents(Author.new(Role(role), name), contents)
else:
msg = Message.from_role_and_contents(Role(role), contents)
channel = chat_msg.get("channel")
if channel:
msg = msg.with_channel(channel)
recipient = chat_msg.get("recipient")
if recipient:
msg = msg.with_recipient(recipient)
content_type = chat_msg.get("content_type")
if content_type:
msg = msg.with_content_type(content_type)
return msg
def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
"""Parse an OpenAI chat-format dict into Harmony messages."""
role = chat_msg.get("role")
if role is None:
raise ValueError(f"Message has no 'role' key: {chat_msg}")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
if name and not name.startswith("functions."):
name = f"functions.{name}"
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
# NOTE: .with_recipient("assistant") is required on tool messages
# to match parse_chat_input_to_harmony_message behavior and ensure
# proper routing in the Harmony protocol.
msg = (
Message.from_author_and_content(Author.new(Role.TOOL, name), content)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Default: user/assistant/system messages
content = chat_msg.get("content", "")
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents)
return [msg]
# ---------------------------------------------------------------------------
# 2. Public input parsing functions
# ---------------------------------------------------------------------------
def response_input_to_harmony(
response_msg: ResponseInputOutputItem,
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
) -> Message:
"""Convert a single ResponseInputOutputItem into a Harmony Message."""
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
# Add prefix for developer messages.
# <|start|>developer<|message|># Instructions {instructions}<|end|>
text_prefix = "Instructions:\n" if role == "developer" else ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: ResponseFunctionToolCall | None = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
and prev_response.call_id == call_id
):
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"],
)
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def response_previous_input_to_harmony(chat_msg) -> list[Message]:
"""Parse a message from request.previous_input_messages
into Harmony messages.
Supports both OpenAI chat format ({"role": "..."}) and
Harmony format ({"author": {"role": "..."}}).
"""
if not isinstance(chat_msg, dict):
chat_msg = chat_msg.model_dump(exclude_none=True)
if "author" in chat_msg and isinstance(chat_msg.get("author"), dict):
return [_parse_harmony_format_message(chat_msg)]
return _parse_chat_format_message(chat_msg)
def construct_harmony_previous_input_messages(
request: ResponsesRequest,
) -> list[Message]:
"""Build a Harmony message list from request.previous_input_messages.
Filters out system/developer messages to match OpenAI behavior where
instructions are always taken from the most recent Responses API request.
"""
messages: list[Message] = []
if request.previous_input_messages:
for message in request.previous_input_messages:
# Handle both Message objects and dictionary inputs
if isinstance(message, Message):
message_role = message.author.role
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
continue
messages.append(message)
else:
harmony_messages = response_previous_input_to_harmony(message)
for harmony_msg in harmony_messages:
message_role = harmony_msg.author.role
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
continue
messages.append(harmony_msg)
return messages
# ---------------------------------------------------------------------------
# 3. Private helpers for output parsing
# ---------------------------------------------------------------------------
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
"""Parse browser tool calls (search, open, find) into web search items."""
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
# Parse JSON args (with retry detection)
try:
browser_call = json.loads(content.text)
except json.JSONDecodeError:
logger.warning(
"Invalid JSON in browser tool call, using error placeholder: %s",
content.text,
)
json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}"
)
browser_call = {
"query": json_retry_output_message,
"url": json_retry_output_message,
"pattern": json_retry_output_message,
}
# Create appropriate action based on recipient
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call.get("pattern", ""),
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
return ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse function calls into function tool call items."""
function_name = recipient.split(".")[-1]
output_items = []
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
return output_items
def _parse_reasoning(message: Message) -> list[ResponseOutputItem]:
"""Parse reasoning/analysis content into reasoning items."""
output_items = []
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
return output_items
def _parse_final_message(message: Message) -> ResponseOutputItem:
"""Parse final channel messages into output message items."""
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
return ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
def _parse_mcp_recipient(recipient: str) -> tuple[str, str]:
"""Parse MCP recipient into (server_label, tool_name).
For dotted recipients like "repo_browser.list":
- server_label: "repo_browser" (namespace/server)
- tool_name: "list" (specific tool)
For simple recipients like "filesystem":
- server_label: "filesystem"
- tool_name: "filesystem"
"""
if "." in recipient:
server_label = recipient.split(".")[0]
tool_name = recipient.split(".")[-1]
else:
server_label = recipient
tool_name = recipient
return server_label, tool_name
def _parse_mcp_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse MCP calls into MCP call items."""
# Handle built-in tools that need server_label mapping
if recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
server_label = BUILTIN_TOOL_TO_MCP_SERVER_LABEL[recipient]
tool_name = recipient
else:
server_label, tool_name = _parse_mcp_recipient(recipient)
output_items = []
for content in message.content:
response_item = McpCall(
arguments=content.text,
type="mcp_call",
name=tool_name,
server_label=server_label,
id=f"mcp_{random_uuid()}",
status="completed",
)
output_items.append(response_item)
return output_items
def _parse_message_no_recipient(
message: Message,
) -> list[ResponseOutputItem]:
"""Parse a Harmony message with no recipient based on its channel."""
if message.channel == "analysis":
return _parse_reasoning(message)
if message.channel in ("commentary", "final"):
# Per Harmony format, preambles (commentary with no recipient) and
# final channel content are both intended to be shown to end-users.
# See: https://cookbook.openai.com/articles/openai-harmony
return [_parse_final_message(message)]
raise ValueError(f"Unknown channel: {message.channel}")
# ---------------------------------------------------------------------------
# 4. Public output parsing functions
# ---------------------------------------------------------------------------
def harmony_to_response_output(message: Message) -> list[ResponseOutputItem]:
"""Parse a Harmony message into a list of output response items.
This is the main dispatcher that routes based on channel and recipient.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None:
# Browser tool calls (browser.search, browser.open, browser.find)
if recipient.startswith("browser."):
output_items.append(_parse_browser_tool_call(message, recipient))
# Function calls (should only happen on commentary channel)
elif message.channel == "commentary" and recipient.startswith("functions."):
output_items.extend(_parse_function_call(message, recipient))
# Built-in MCP tools (python, browser, container)
elif recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
output_items.extend(_parse_reasoning(message))
# All other recipients are MCP calls
else:
output_items.extend(_parse_mcp_call(message, recipient))
# No recipient - handle based on channel for non-tool messages
else:
output_items.extend(_parse_message_no_recipient(message))
return output_items
def parser_state_to_response_output(
parser: StreamableParser,
) -> list[ResponseOutputItem]:
"""Extract in-progress response items from incomplete parser state.
Called when the parser has buffered content that hasn't formed a
complete message yet (e.g., generation was cut short).
"""
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if current_recipient is not None and current_recipient.startswith("browser."):
return []
if current_recipient and parser.current_channel in ("commentary", "analysis"):
if current_recipient.startswith("functions."):
rid = random_uuid()
return [
ResponseFunctionToolCall(
arguments=parser.current_content,
call_id=f"call_{rid}",
type="function_call",
name=current_recipient.split(".")[-1],
id=f"fc_{rid}",
status="in_progress",
)
]
# Built-in MCP tools (python, browser, container)
elif current_recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
return [
ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
]
# All other recipients are MCP calls
else:
rid = random_uuid()
server_label, tool_name = _parse_mcp_recipient(current_recipient)
return [
McpCall(
arguments=parser.current_content,
type="mcp_call",
name=tool_name,
server_label=server_label,
id=f"mcp_{rid}",
status="in_progress",
)
]
if parser.current_channel == "commentary":
# Per Harmony format, preambles (commentary with no recipient) are
# intended to be shown to end-users, unlike analysis channel content.
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[],
type="output_text",
logprobs=None,
)
return [
ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
status="incomplete",
type="message",
)
]
if parser.current_channel == "analysis":
return [
ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
]
if parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
# if the parser still has messages (ie if the generator got cut
# abruptly), this should be incomplete
status="incomplete",
type="message",
)
return [text_item]
return []

View File

@@ -0,0 +1,641 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Literal, TypeAlias
import torch
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionToolCall,
ResponseInputItemParam,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallInProgressEvent,
ResponseOutputItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponsePrompt,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses import (
ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
from openai.types.responses import (
ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
from openai.types.responses.tool import Tool
from openai_harmony import Message as OpenAIHarmonyMessage
# Backward compatibility for OpenAI client versions
try: # For older openai versions (< 1.100.0)
from openai.types.responses import ResponseTextConfig
except ImportError: # For newer openai versions (>= 1.100.0)
from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
from openai.types.responses.response import IncompleteDetails, ToolChoice
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.shared import Metadata, Reasoning
from pydantic import (
Field,
ValidationError,
field_serializer,
model_validator,
)
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class InputTokensDetails(OpenAIBaseModel):
cached_tokens: int
input_tokens_per_turn: list[int] = Field(default_factory=list)
cached_tokens_per_turn: list[int] = Field(default_factory=list)
class OutputTokensDetails(OpenAIBaseModel):
reasoning_tokens: int = 0
tool_output_tokens: int = 0
output_tokens_per_turn: list[int] = Field(default_factory=list)
tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
class ResponseUsage(OpenAIBaseModel):
input_tokens: int
input_tokens_details: InputTokensDetails
output_tokens: int
output_tokens_details: OutputTokensDetails
total_tokens: int
def serialize_message(msg):
"""
Serializes a single message
"""
if isinstance(msg, dict):
return msg
elif hasattr(msg, "to_dict"):
return msg.to_dict()
else:
# fallback to pyandic dump
return msg.model_dump_json()
def serialize_messages(msgs):
"""
Serializes multiple messages
"""
return [serialize_message(msg) for msg in msgs] if msgs else None
class ResponseRawMessageAndToken(OpenAIBaseModel):
"""Class to show the raw message.
If message / tokens diverge, tokens is the source of truth"""
message: str
tokens: list[int]
type: Literal["raw_message_tokens"] = "raw_message_tokens"
ResponseInputOutputMessage: TypeAlias = (
list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/responses/create
background: bool | None = False
include: (
list[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
],
]
| None
) = None
input: str | list[ResponseInputOutputItem]
instructions: str | None = None
max_output_tokens: int | None = None
max_tool_calls: int | None = None
metadata: Metadata | None = None
model: str | None = None
logit_bias: dict[str, float] | None = None
parallel_tool_calls: bool | None = True
previous_response_id: str | None = None
prompt: ResponsePrompt | None = None
reasoning: Reasoning | None = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
store: bool | None = True
stream: bool | None = False
temperature: float | None = None
text: ResponseTextConfig | None = None
tool_choice: ToolChoice = "auto"
tools: list[Tool] = Field(default_factory=list)
top_logprobs: int | None = 0
top_p: float | None = None
top_k: int | None = None
truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None
skip_special_tokens: bool = True
include_stop_str_in_output: bool = False
prompt_cache_key: str | None = Field(
default=None,
description=(
"A key that was used to read from or write to the prompt cache."
"Note: This field has not been implemented yet "
"and vLLM will ignore it."
),
)
# --8<-- [start:responses-extra-params]
request_id: str = Field(
default_factory=lambda: f"resp_{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
enable_response_messages: bool = Field(
default=False,
description=(
"Dictates whether or not to return messages as part of the "
"response object. Currently only supported for non-background."
),
)
# similar to input_messages / output_messages in ResponsesResponse
# we take in previous_input_messages (ie in harmony format)
# this cannot be used in conjunction with previous_response_id
# TODO: consider supporting non harmony messages as well
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
structured_outputs: StructuredOutputsParams | None = Field(
default=None,
description="Additional kwargs for structured outputs",
)
repetition_penalty: float | None = None
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
ignore_eos: bool = False
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:responses-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
from .utils import should_continue_final_message
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(self.input)
reasoning = self.reasoning
return ChatParams(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_output_tokens or 0,
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_output_tokens",
)
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
}
def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: dict | None = None,
) -> SamplingParams:
if self.max_output_tokens is None:
max_tokens = default_max_tokens
else:
max_tokens = min(self.max_output_tokens, default_max_tokens)
default_sampling_params = default_sampling_params or {}
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)
stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output
structured_outputs = self.structured_outputs
# Also check text.format for OpenAI-style json_schema
if self.text is not None and self.text.format is not None:
if structured_outputs is not None:
raise ValueError(
"Cannot specify both structured_outputs and text.format"
)
response_format = self.text.format
if (
response_format.type == "json_schema"
and response_format.schema_ is not None
):
structured_outputs = StructuredOutputsParams(
json=response_format.schema_ # type: ignore[call-arg]
# --follow-imports skip hides the class definition but also hides
# multiple third party conflicts, so best of both evils
)
stop = self.stop if self.stop else []
if isinstance(stop, str):
stop = [stop]
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids,
stop=stop,
repetition_penalty=repetition_penalty,
seed=self.seed,
ignore_eos=self.ignore_eos,
output_kind=(
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {},
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def is_include_output_logprobs(self) -> bool:
"""Check if the request includes output logprobs."""
if self.include is None:
return False
return (
isinstance(self.include, list)
and "message.output_text.logprobs" in self.include
)
@model_validator(mode="before")
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
return data
@model_validator(mode="before")
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise VLLMValidationError(
"prompt template is not supported", parameter="prompt"
)
return data
@model_validator(mode="before")
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
@model_validator(mode="before")
def function_call_parsing(cls, data):
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
This ensures Pydantic can properly resolve union types in the input field.
Function calls provided as dicts are converted to ResponseFunctionToolCall
objects before validation, while invalid structures are left for Pydantic
to reject with appropriate error messages.
"""
input_data = data.get("input")
# Early return for None, strings, or bytes
# (strings are iterable but shouldn't be processed)
if input_data is None or isinstance(input_data, (str, bytes)):
return data
# Convert iterators (like ValidatorIterator) to list
if not isinstance(input_data, list):
try:
input_data = list(input_data)
except TypeError:
# Not iterable, leave as-is for Pydantic to handle
return data
processed_input = []
for item in input_data:
if isinstance(item, dict) and item.get("type") == "function_call":
try:
processed_input.append(ResponseFunctionToolCall(**item))
except ValidationError:
# Let Pydantic handle validation for malformed function calls
logger.debug(
"Failed to parse function_call to ResponseFunctionToolCall, "
"leaving for Pydantic validation"
)
processed_input.append(item)
else:
processed_input.append(item)
data["input"] = processed_input
return data
class ResponsesResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
created_at: int = Field(default_factory=lambda: int(time.time()))
# error: Optional[ResponseError] = None
incomplete_details: IncompleteDetails | None = None
instructions: str | None = None
metadata: Metadata | None = None
model: str
object: Literal["response"] = "response"
output: list[ResponseOutputItem]
parallel_tool_calls: bool
temperature: float
tool_choice: ToolChoice
tools: list[Tool]
top_p: float
background: bool
max_output_tokens: int
max_tool_calls: int | None = None
previous_response_id: str | None = None
prompt: ResponsePrompt | None = None
reasoning: Reasoning | None = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
status: ResponseStatus
text: ResponseTextConfig | None = None
top_logprobs: int | None = None
truncation: Literal["auto", "disabled"]
usage: ResponseUsage | None = None
user: str | None = None
# --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed
# see serialize_input_messages and serialize_output_messages
input_messages: ResponseInputOutputMessage | None = Field(
default=None,
description=(
"If enable_response_messages, we can show raw token input to model."
),
)
output_messages: ResponseInputOutputMessage | None = Field(
default=None,
description=(
"If enable_response_messages, we can show raw token output of model."
),
)
# --8<-- [end:responses-response-extra-params]
# NOTE: openAI harmony doesn't serialize TextContent properly,
# TODO: this fixes for TextContent, but need to verify for tools etc
# https://github.com/openai/harmony/issues/78
@field_serializer("output_messages", when_used="json")
def serialize_output_messages(self, msgs, _info):
return serialize_messages(msgs)
# NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
# https://github.com/openai/harmony/issues/78
@field_serializer("input_messages", when_used="json")
def serialize_input_messages(self, msgs, _info):
return serialize_messages(msgs)
@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
created_time: int,
output: list[ResponseOutputItem],
status: ResponseStatus,
usage: ResponseUsage | None = None,
input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None,
) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None
if status == "incomplete":
incomplete_details = IncompleteDetails(reason="max_output_tokens")
# TODO: implement the other reason for incomplete_details,
# which is content_filter
# incomplete_details = IncompleteDetails(reason='content_filter')
return cls(
id=request.request_id,
created_at=created_time,
incomplete_details=incomplete_details,
instructions=request.instructions,
metadata=request.metadata,
model=model_name,
output=output,
input_messages=input_messages,
output_messages=output_messages,
parallel_tool_calls=request.parallel_tool_calls,
temperature=sampling_params.temperature,
tool_choice=request.tool_choice,
tools=request.tools,
top_p=sampling_params.top_p,
background=request.background,
max_output_tokens=sampling_params.max_tokens,
max_tool_calls=request.max_tool_calls,
previous_response_id=request.previous_response_id,
prompt=request.prompt,
reasoning=request.reasoning,
service_tier=request.service_tier,
status=status,
text=request.text,
top_logprobs=sampling_params.logprobs,
truncation=request.truncation,
user=request.user,
usage=usage,
)
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
content_index: int
"""The index of the content part that is done."""
item_id: str
"""The ID of the output item that the content part was added to."""
output_index: int
"""The index of the output item that the content part was added to."""
part: ResponseReasoningTextContent
"""The content part that is done."""
sequence_number: int
"""The sequence number of this event."""
type: Literal["response.reasoning_part.done"]
"""The type of the event. Always `response.reasoning_part.done`."""
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
content_index: int
"""The index of the content part that is done."""
item_id: str
"""The ID of the output item that the content part was added to."""
output_index: int
"""The index of the output item that the content part was added to."""
part: ResponseReasoningTextContent
"""The content part that is done."""
sequence_number: int
"""The sequence number of this event."""
type: Literal["response.reasoning_part.added"]
"""The type of the event. Always `response.reasoning_part.added`."""
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
response: ResponsesResponse # type: ignore[override]
class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
response: ResponsesResponse # type: ignore[override]
class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
response: ResponsesResponse # type: ignore[override]
StreamingResponsesResponse: TypeAlias = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseCompletedEvent
| ResponseOutputItemAddedEvent
| ResponseOutputItemDoneEvent
| ResponseContentPartAddedEvent
| ResponseContentPartDoneEvent
| ResponseReasoningTextDeltaEvent
| ResponseReasoningTextDoneEvent
| ResponseReasoningPartAddedEvent
| ResponseReasoningPartDoneEvent
| ResponseCodeInterpreterCallInProgressEvent
| ResponseCodeInterpreterCallCodeDeltaEvent
| ResponseWebSearchCallInProgressEvent
| ResponseWebSearchCallSearchingEvent
| ResponseWebSearchCallCompletedEvent
| ResponseCodeInterpreterCallCodeDoneEvent
| ResponseCodeInterpreterCallInterpretingEvent
| ResponseCodeInterpreterCallCompletedEvent
| ResponseMcpCallArgumentsDeltaEvent
| ResponseMcpCallArgumentsDoneEvent
| ResponseMcpCallInProgressEvent
| ResponseMcpCallCompletedEvent
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,798 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Streaming SSE event builders for the Responses API.
Pure functions that translate streaming state + delta data into
OpenAI Response API SSE events. Used by the streaming event
processors in serving.py.
The file is organized as:
1. StreamingState dataclass + utility helpers
2. Shared leaf helpers — delta events (take plain strings, no context)
3. Shared leaf helpers — done events (take plain strings, no context)
4. Harmony-specific dispatchers (route ctx/previous_item → leaf helpers)
5. Harmony-specific tool lifecycle helpers
"""
import json
from dataclasses import dataclass
from typing import Final
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterToolCallParam,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
response_function_web_search,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai_harmony import Message as HarmonyMessage
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.responses.context import StreamingHarmonyContext
from vllm.entrypoints.openai.responses.protocol import (
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
StreamingResponsesResponse,
)
from vllm.utils import random_uuid
TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
"python": "code_interpreter",
"container": "container",
"browser": "web_search_preview",
}
def _resolve_mcp_name_label(recipient: str) -> tuple[str, str]:
"""Resolve MCP tool name and server label from a recipient string.
- ``mcp.*`` recipients: strip prefix, use the bare name as both
name and server_label.
- Everything else: use the recipient as the name and look up the
server_label in TOOL_NAME_TO_MCP_SERVER_LABEL.
"""
if recipient.startswith("mcp."):
name = recipient[len("mcp.") :]
return name, name
return recipient, TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
@dataclass
class StreamingState:
"""Mutable state for streaming event processing."""
current_content_index: int = -1
current_output_index: int = 0
current_item_id: str = ""
current_call_id: str = ""
sent_output_item_added: bool = False
is_first_function_call_delta: bool = False
def reset_for_new_item(self) -> None:
"""Reset state when expecting a new output item."""
self.current_output_index += 1
self.sent_output_item_added = False
self.is_first_function_call_delta = False
self.current_call_id = ""
def is_mcp_tool_by_namespace(recipient: str | None) -> bool:
"""
Determine if a tool call is an MCP tool based on recipient prefix.
- Tools starting with "functions." are function calls
- Everything else is an MCP tool
"""
if recipient is None:
return False
# Function calls have "functions." prefix
# Everything else is an MCP tool
return not recipient.startswith("functions.")
# =====================================================================
# Shared leaf helpers — delta events
# =====================================================================
def emit_text_delta_events(
delta: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for text content delta streaming."""
events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
)
)
state.current_content_index += 1
events.append(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
content_index=state.current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
events.append(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=state.current_content_index,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
)
return events
def emit_reasoning_delta_events(
delta: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for reasoning text delta streaming."""
events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseReasoningItem(
type="reasoning",
id=state.current_item_id,
summary=[],
status="in_progress",
),
)
)
state.current_content_index += 1
events.append(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
content_index=state.current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
events.append(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
delta=delta,
sequence_number=-1,
)
)
return events
def emit_function_call_delta_events(
delta: str,
function_name: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for function call argument deltas."""
events: list[StreamingResponsesResponse] = []
if state.is_first_function_call_delta is False:
state.is_first_function_call_delta = True
state.current_item_id = f"fc_{random_uuid()}"
state.current_call_id = f"call_{random_uuid()}"
tool_call_item = ResponseFunctionToolCall(
name=function_name,
type="function_call",
id=state.current_item_id,
call_id=state.current_call_id,
arguments="",
status="in_progress",
)
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=tool_call_item,
)
)
# Always emit the delta (including on first call)
events.append(
ResponseFunctionCallArgumentsDeltaEvent(
item_id=state.current_item_id,
delta=delta,
output_index=state.current_output_index,
sequence_number=-1,
type="response.function_call_arguments.delta",
)
)
return events
def emit_mcp_delta_events(
delta: str,
state: StreamingState,
recipient: str,
) -> list[StreamingResponsesResponse]:
"""Emit events for MCP tool delta streaming."""
name, server_label = _resolve_mcp_name_label(recipient)
events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=name,
arguments="",
server_label=server_label,
status="in_progress",
),
)
)
events.append(
ResponseMcpCallInProgressEvent(
type="response.mcp_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseMcpCallArgumentsDeltaEvent(
type="response.mcp_call_arguments.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=delta,
)
)
return events
def emit_code_interpreter_delta_events(
delta: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for code interpreter delta streaming."""
events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"tool_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=state.current_item_id,
code=None,
container_id="auto",
outputs=None,
status="in_progress",
),
)
)
events.append(
ResponseCodeInterpreterCallInProgressEvent(
type="response.code_interpreter_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=delta,
)
)
return events
# =====================================================================
# Shared leaf helpers — done events
# =====================================================================
def emit_text_output_done_events(
text: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a final text output item completes."""
text_content = ResponseOutputText(
type="output_text",
text=text,
annotations=[],
)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=text,
logprobs=[],
item_id=state.current_item_id,
)
)
events.append(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=text_content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
)
)
return events
def emit_reasoning_done_events(
text: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
content = ResponseReasoningTextContent(
text=text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[content],
status="completed",
id=state.current_item_id,
summary=[],
)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=state.current_item_id,
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=text,
)
)
events.append(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=reasoning_item,
)
)
return events
def emit_function_call_done_events(
function_name: str,
arguments: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a function call completes."""
events: list[StreamingResponsesResponse] = []
events.append(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
arguments=arguments,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call",
arguments=arguments,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
call_id=state.current_call_id,
status="completed",
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=function_call_item,
)
)
return events
def emit_mcp_completion_events(
recipient: str,
arguments: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool call completes."""
name, server_label = _resolve_mcp_name_label(recipient)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
arguments=arguments,
name=name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
arguments=arguments,
name=name,
id=state.current_item_id,
server_label=server_label,
status="completed",
),
)
)
return events
# =====================================================================
# Harmony-specific dispatchers
# =====================================================================
def emit_content_delta_events(
ctx: StreamingHarmonyContext,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for content delta streaming based on channel type.
This is a Harmony-specific dispatcher that extracts values from the
Harmony context and delegates to shared leaf helpers.
"""
delta = ctx.last_content_delta
if not delta:
return []
channel = ctx.parser.current_channel
recipient = ctx.parser.current_recipient
if channel in ("final", "commentary") and recipient is None:
# Preambles (commentary with no recipient) and final messages
# are both user-visible text.
return emit_text_delta_events(delta, state)
elif channel == "analysis" and recipient is None:
return emit_reasoning_delta_events(delta, state)
# built-in tools will be triggered on the analysis channel
# However, occasionally built-in tools will
# still be output to commentary.
elif channel in ("commentary", "analysis") and recipient is not None:
if recipient.startswith("functions."):
function_name = recipient[len("functions.") :]
return emit_function_call_delta_events(delta, function_name, state)
elif recipient == "python":
return emit_code_interpreter_delta_events(delta, state)
elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
return emit_mcp_delta_events(delta, state, recipient)
return []
def emit_previous_item_done_events(
previous_item: HarmonyMessage,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit done events for the previous item when expecting a new start.
This is a Harmony-specific dispatcher that extracts values from the
Harmony parser's message object and delegates to shared leaf helpers.
"""
text = previous_item.content[0].text
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
function_name = previous_item.recipient[len("functions.") :]
return emit_function_call_done_events(function_name, text, state)
elif previous_item.recipient == "python":
return emit_code_interpreter_completion_events(previous_item, state)
elif (
is_mcp_tool_by_namespace(previous_item.recipient)
and state.current_item_id is not None
and state.current_item_id.startswith("mcp_")
):
return emit_mcp_completion_events(previous_item.recipient, text, state)
elif previous_item.channel == "analysis":
return emit_reasoning_done_events(text, state)
elif previous_item.channel in ("commentary", "final"):
# Preambles (commentary with no recipient) and final messages
# are both user-visible text.
return emit_text_output_done_events(text, state)
return []
# =====================================================================
# Harmony-specific tool lifecycle helpers
# =====================================================================
def emit_browser_tool_events(
previous_item: HarmonyMessage,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for browser tool calls (web search)."""
function_name = previous_item.recipient[len("browser.") :]
parsed_args = json.loads(previous_item.content[0].text)
action = None
if function_name == "search":
action = response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
)
elif function_name == "open":
action = response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
elif function_name == "find":
action = response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
else:
raise ValueError(f"Unknown function name: {function_name}")
state.current_item_id = f"tool_{random_uuid()}"
events: list[StreamingResponsesResponse] = []
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=response_function_web_search.ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
id=state.current_item_id,
action=action,
status="in_progress",
),
)
)
events.append(
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
# enqueue
events.append(
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseFunctionWebSearch(
type="web_search_call",
id=state.current_item_id,
action=action,
status="completed",
),
)
)
return events
def emit_code_interpreter_completion_events(
previous_item: HarmonyMessage,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when code interpreter completes."""
events: list[StreamingResponsesResponse] = []
events.append(
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
code=previous_item.content[0].text,
)
)
events.append(
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=state.current_item_id,
code=previous_item.content[0].text,
container_id="auto",
outputs=[],
status="completed",
),
)
)
return events
def emit_tool_action_events(
ctx: StreamingHarmonyContext,
state: StreamingState,
tool_server: ToolServer | None,
) -> list[StreamingResponsesResponse]:
"""Emit events for tool action turn."""
if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0:
return []
events: list[StreamingResponsesResponse] = []
previous_item = ctx.parser.messages[-1]
# Handle browser tool
if (
tool_server is not None
and tool_server.has_tool("browser")
and previous_item.recipient is not None
and previous_item.recipient.startswith("browser.")
):
events.extend(emit_browser_tool_events(previous_item, state))
# Handle tool completion
if (
tool_server is not None
and previous_item.recipient is not None
and state.current_item_id is not None
and state.sent_output_item_added
):
recipient = previous_item.recipient
if recipient == "python":
events.extend(emit_code_interpreter_completion_events(previous_item, state))
elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
events.extend(
emit_mcp_completion_events(
recipient, previous_item.content[0].text, state
)
)
return events

View File

@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool,
)
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionMessageParam
from vllm.entrypoints.openai.responses.protocol import ResponseInputOutputItem
def should_continue_final_message(
request_input: str | list[ResponseInputOutputItem],
) -> bool:
"""
Determine if the last input message is a partial assistant message
that should be continued rather than starting a new generation.
This enables partial message completion similar to Anthropic's Messages API,
where users can provide an incomplete assistant message and have the model
continue from where it left off.
A message is considered partial if:
1. It's a ResponseOutputMessage or ResponseReasoningItem
2. Its status is "in_progress" or "incomplete"
Args:
request_input: The input to the Responses API request
Returns:
True if the final message should be continued, False otherwise
"""
if isinstance(request_input, str):
# Simple string input is always a user message
return False
if not request_input:
return False
last_item = request_input[-1]
# Check if the last item is a partial assistant message
if isinstance(last_item, ResponseOutputMessage):
return last_item.status in ("in_progress", "incomplete")
# Check if the last item is a partial reasoning item
if isinstance(last_item, ResponseReasoningItem):
return last_item.status in ("in_progress", "incomplete")
if isinstance(last_item, dict):
# only support partial completion for messages for now
if last_item.get("type", "message") not in ("message", "reasoning"):
return False
return last_item.get("status") in ("in_progress", "incomplete")
return False
def construct_input_messages(
*,
request_instructions: str | None = None,
request_input: str | list[ResponseInputOutputItem],
prev_msg: list[ChatCompletionMessageParam] | None = None,
prev_response_output: list[ResponseOutputItem] | None = None,
):
messages: list[ChatCompletionMessageParam] = []
if request_instructions:
messages.append(
{
"role": "system",
"content": request_instructions,
}
)
# Prepend the conversation history.
if prev_msg is not None:
# Add the previous messages.
messages.extend(prev_msg)
if prev_response_output is not None:
# Add the previous output.
for output_item in prev_response_output:
# NOTE: We skip the reasoning output.
if isinstance(output_item, ResponseOutputMessage):
for content in output_item.content:
messages.append(
{
"role": "assistant",
"content": content.text,
}
)
# Append the new input.
# Responses API supports simple text inputs without chat format.
if isinstance(request_input, str):
messages.append({"role": "user", "content": request_input})
else:
input_messages = construct_chat_messages_with_tool_call(request_input)
messages.extend(input_messages)
return messages
def _maybe_combine_reasoning_and_tool_call(
item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
"""Many models treat MCP calls and reasoning as a single message.
This function checks if the last message is a reasoning message and
the current message is a tool call"""
if not (
isinstance(item, ResponseFunctionToolCall)
and item.id
and item.id.startswith(MCP_PREFIX)
):
return None
if len(messages) == 0:
return None
last_message = messages[-1]
if not (
last_message.get("role") == "assistant"
and last_message.get("reasoning") is not None
):
return None
last_message["tool_calls"] = [
ChatCompletionMessageToolCallParam(
id=item.call_id,
function=FunctionCallTool(
name=item.name,
arguments=item.arguments,
),
type="function",
)
]
return last_message
def construct_chat_messages_with_tool_call(
input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
"""This function wraps _construct_single_message_from_response_item
Because some chatMessages come from multiple response items
for example a reasoning item and a MCP tool call are two response items
but are one chat message
"""
messages: list[ChatCompletionMessageParam] = []
for item in input_messages:
maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
if maybe_combined_message is not None:
messages[-1] = maybe_combined_message
else:
messages.append(_construct_single_message_from_response_item(item))
return messages
def _construct_single_message_from_response_item(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
if isinstance(item, ResponseFunctionToolCall):
# Append the function call as a tool call.
return ChatCompletionAssistantMessageParam(
role="assistant",
tool_calls=[
ChatCompletionMessageToolCallParam(
id=item.call_id,
function=FunctionCallTool(
name=item.name,
arguments=item.arguments,
),
type="function",
)
],
)
elif isinstance(item, ResponseReasoningItem):
reasoning_content = ""
if item.encrypted_content:
raise ValueError("Encrypted content is not supported.")
if len(item.summary) == 1:
reasoning_content = item.summary[0].text
elif item.content and len(item.content) == 1:
reasoning_content = item.content[0].text
return {
"role": "assistant",
"reasoning": reasoning_content,
}
elif isinstance(item, ResponseOutputMessage):
return {
"role": "assistant",
"content": item.content[0].text,
}
elif isinstance(item, ResponseFunctionToolCallOutputItem):
return ChatCompletionToolMessageParam(
role="tool",
content=item.output,
tool_call_id=item.call_id,
)
elif isinstance(item, dict) and item.get("type") == "function_call_output":
# Append the function call output as a tool message.
return ChatCompletionToolMessageParam(
role="tool",
content=item.get("output"),
tool_call_id=item.get("call_id"),
)
return item # type: ignore
def extract_tool_types(tools: list[Tool]) -> set[str]:
"""
Extracts the tool types from the given tools.
"""
tool_types: set[str] = set()
for tool in tools:
if tool.type == "mcp":
# Allow the MCP Tool type to enable built in tools if the
# server_label is allowlisted in
# envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if tool.server_label in envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
tool_types.add(tool.server_label)
else:
tool_types.add(tool.type)
return tool_types
def convert_tool_responses_to_completions_format(tool: dict) -> dict:
"""
Convert a flat tool schema:
{"type": "function", "name": "...", "description": "...", "parameters": {...}}
into:
{"type": "function", "function": {...}}
"""
return {
"type": "function",
"function": tool,
}
def construct_tool_dicts(
tools: list[Tool], tool_choice: ToolChoice
) -> list[dict[str, Any]] | None:
if tools is None or (tool_choice == "none"):
tool_dicts = None
else:
tool_dicts = [
convert_tool_responses_to_completions_format(tool.model_dump())
for tool in tools
]
return tool_dicts

View File

@@ -0,0 +1,843 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import sys
import tempfile
from argparse import Namespace
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from io import BytesIO, StringIO
from typing import Any, TypeAlias
from urllib.parse import urlparse
import aiohttp
import torch
from fastapi import UploadFile
from prometheus_client import start_http_server
from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State
from tqdm import tqdm
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import init_app_state
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from vllm.entrypoints.openai.cli_args import BaseFrontendArgs
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
OpenAIBaseModel,
)
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationRequest,
TranslationResponse,
TranslationResponseVerbose,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
class BatchTranscriptionRequest(TranscriptionRequest):
"""
Batch transcription request that uses file_url instead of file.
This class extends TranscriptionRequest but replaces the file field
with file_url to support batch processing from audio files written in JSON format.
"""
file_url: str = Field(
...,
description=(
"Either a URL of the audio or a data URL with base64 encoded audio data. "
),
)
# Override file to be optional and unused for batch processing
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
@model_validator(mode="before")
@classmethod
def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data:
raise ValueError(
"The 'file' field is not supported in batch requests. "
"Use 'file_url' instead."
)
return data
class BatchTranslationRequest(TranslationRequest):
"""
Batch translation request that uses file_url instead of file.
This class extends TranslationRequest but replaces the file field
with file_url to support batch processing from audio files written in JSON format.
"""
file_url: str = Field(
...,
description=(
"Either a URL of the audio or a data URL with base64 encoded audio data. "
),
)
# Override file to be optional and unused for batch processing
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
@model_validator(mode="before")
@classmethod
def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data:
raise ValueError(
"The 'file' field is not supported in batch requests. "
"Use 'file_url' instead."
)
return data
BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest
| EmbeddingRequest
| ScoreRequest
| RerankRequest
| BatchTranscriptionRequest
| BatchTranslationRequest
)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: BatchRequestInputBody
@field_validator("body", mode="plain")
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url: str = info.data["url"]
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url.endswith("/score"):
return TypeAdapter(ScoreRequest).validate_python(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
if url == "/v1/audio/transcriptions":
return BatchTranscriptionRequest.model_validate(value)
if url == "/v1/audio/translations":
return BatchTranslationRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| None
) = None
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: BatchResponseData | None
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Any | None
@config
class BatchFrontendArgs(BaseFrontendArgs):
"""Arguments for the batch runner frontend."""
input_file: str | None = None
"""The path or url to a single input file. Currently supports local file
paths, or the http protocol (http or https). If a URL is specified,
the file should be available via HTTP GET."""
output_file: str | None = None
"""The path or url to a single output file. Currently supports
local file paths, or web (http or https) urls. If a URL is specified,
the file should be available via HTTP PUT."""
output_tmp_dir: str | None = None
"""The directory to store the output file before uploading it
to the output URL."""
enable_metrics: bool = False
"""Enable Prometheus metrics"""
host: str | None = None
"""Host name for the Prometheus metrics server
(only needed if enable-metrics is set)."""
port: int = 8000
"""Port number for the Prometheus metrics server
(only needed if enable-metrics is set)."""
url: str = "0.0.0.0"
"""[DEPRECATED] Host name for the Prometheus metrics server
(only needed if enable-metrics is set). Use --host instead."""
@classmethod
def _customize_cli_kwargs(
cls,
frontend_kwargs: dict[str, Any],
) -> dict[str, Any]:
frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)
frontend_kwargs["input_file"]["flags"] = ["-i"]
frontend_kwargs["input_file"]["required"] = True
frontend_kwargs["output_file"]["flags"] = ["-o"]
frontend_kwargs["output_file"]["required"] = True
frontend_kwargs["enable_metrics"]["action"] = "store_true"
frontend_kwargs["url"]["deprecated"] = True
return frontend_kwargs
def make_arg_parser(parser: FlexibleArgumentParser):
parser = BatchFrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def parse_args():
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
args = make_arg_parser(parser).parse_args()
# Backward compatibility: If --url is set, use it for host
url_explicit = any(arg == "--url" or arg.startswith("--url=") for arg in sys.argv)
host_explicit = any(
arg == "--host" or arg.startswith("--host=") for arg in sys.argv
)
if url_explicit and hasattr(args, "url") and not host_explicit:
args.host = args.url
logger.warning_once(
"Using --url for metrics is deprecated. Please use --host instead."
)
return args
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: tqdm | None = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
self._pbar = tqdm(
total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(
output_path: str, batch_outputs: list[BatchRequestOutput]
) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=1000)
) as session:
if from_file:
with open(data_or_file, "rb") as file:
async with session.put(output_url, data=file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload file.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
else:
async with session.put(output_url, data=data_or_file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload data.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
except Exception as e:
if attempt < max_retries:
logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(
path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str
) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
logger.info("Writing outputs to temporary local file %s", f.name)
await write_local_file(f.name, batch_outputs)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
async def download_bytes_from_url(url: str) -> bytes:
"""
Download data from a URL or decode from a data URL.
Args:
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
Returns:
Data as bytes
"""
parsed = urlparse(url)
# Handle data URLs (base64 encoded)
if parsed.scheme == "data":
# Format: data:...;base64,<base64_data>
if "," in url:
header, data = url.split(",", 1)
if "base64" in header:
return base64.b64decode(data)
else:
raise ValueError(f"Unsupported data URL encoding: {header}")
else:
raise ValueError(f"Invalid data URL format: {url}")
# Handle HTTP/HTTPS URLs
elif parsed.scheme in ("http", "https"):
async with (
aiohttp.ClientSession() as session,
session.get(url) as resp,
):
if resp.status != 200:
raise Exception(
f"Failed to download data from URL: {url}. Status: {resp.status}"
)
return await resp.read()
else:
raise ValueError(
f"Unsupported URL scheme: {parsed.scheme}. "
"Supported schemes: http, https, data"
)
def make_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"vllm-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(
serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker,
) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(
response,
(
ChatCompletionResponse,
EmbeddingResponse,
ScoreResponse,
RerankResponse,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationResponse,
TranslationResponseVerbose,
),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=response, request_id=f"vllm-batch-{random_uuid()}"
),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=response.error.code,
request_id=f"vllm-batch-{random_uuid()}",
),
error=response,
)
else:
batch_output = make_error_request_output(
request, error_msg="Request must not be sent in stream mode"
)
tracker.completed()
return batch_output
WrapperFn: TypeAlias = Callable[[Callable], Callable]
def handle_endpoint_request(
request: BatchRequestInput,
tracker: BatchProgressTracker,
url_matcher: Callable[[str], bool],
handler_getter: Callable[[], Callable | None],
wrapper_fn: WrapperFn | None = None,
) -> Awaitable[BatchRequestOutput] | None:
"""
Generic handler for endpoint requests.
Args:
request: The batch request input
tracker: Progress tracker for the batch
url_matcher: Function that takes a URL and returns True if it matches
handler_getter: Function that returns the handler function or None
wrapper_fn: Optional function to wrap the handler (e.g., for transcriptions)
Returns:
Awaitable[BatchRequestOutput] if the request was handled,
None if URL didn't match
"""
if not url_matcher(request.url):
return None
handler_fn = handler_getter()
if handler_fn is None:
error_msg = f"Model does not support endpoint: {request.url}"
return make_async_error_request_output(request, error_msg=error_msg)
# Apply wrapper if provided (e.g., for transcriptions/translations)
if wrapper_fn is not None:
handler_fn = wrapper_fn(handler_fn)
tracker.submitted()
return run_request(handler_fn, request, tracker)
def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
"""
Factory function to create a wrapper for transcription/translation handlers.
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
to TranscriptionRequest or TranslationRequest and calls the appropriate handler.
Args:
is_translation: If True, process as translation; otherwise process
as transcription
Returns:
A function that takes a handler and returns a wrapped handler
"""
def wrapper(handler_fn: Callable):
async def transcription_wrapper(
batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
) -> (
TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| ErrorResponse
):
try:
# Download data from URL
audio_data = await download_bytes_from_url(batch_request_body.file_url)
# Create a mock file from the downloaded audio data
mock_file = UploadFile(
file=BytesIO(audio_data),
filename="audio.bin",
)
# Convert batch request to regular request
# by copying all fields except file_url and setting file to mock_file
request_dict = batch_request_body.model_dump(exclude={"file_url"})
request_dict["file"] = mock_file
if is_translation:
# Create TranslationRequest from BatchTranslationRequest
translation_request = TranslationRequest.model_validate(
request_dict
)
return await handler_fn(audio_data, translation_request)
else:
# Create TranscriptionRequest from BatchTranscriptionRequest
transcription_request = TranscriptionRequest.model_validate(
request_dict
)
return await handler_fn(audio_data, transcription_request)
except Exception as e:
operation = "translation" if is_translation else "transcription"
return ErrorResponse(
error=ErrorInfo(
message=f"Failed to process {operation}: {str(e)}",
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST.value,
)
)
return transcription_wrapper
return wrapper
async def build_endpoint_registry(
engine_client: EngineClient,
args: Namespace,
) -> dict[str, dict[str, Any]]:
"""
Build the endpoint registry with all serving objects and handler configurations.
Args:
engine_client: The engine client
args: Command line arguments
Returns:
Dictionary mapping endpoint keys to their configurations
"""
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create a state object to hold serving objects
state = State()
# Initialize all serving objects using init_app_state
# This provides full functionality including chat template processing,
# LoRA support, tool servers, etc.
await init_app_state(engine_client, state, args, supported_tasks)
# Get serving objects from state (defaulting to None if not set)
openai_serving_chat = getattr(state, "openai_serving_chat", None)
openai_serving_embedding = getattr(state, "openai_serving_embedding", None)
openai_serving_scores = getattr(state, "openai_serving_scores", None)
openai_serving_transcription = getattr(state, "openai_serving_transcription", None)
openai_serving_translation = getattr(state, "openai_serving_translation", None)
# Registry of endpoint configurations
endpoint_registry: dict[str, dict[str, Any]] = {
"completions": {
"url_matcher": lambda url: url == "/v1/chat/completions",
"handler_getter": lambda: (
openai_serving_chat.create_chat_completion
if openai_serving_chat is not None
else None
),
"wrapper_fn": None,
},
"embeddings": {
"url_matcher": lambda url: url == "/v1/embeddings",
"handler_getter": lambda: (
openai_serving_embedding.create_embedding
if openai_serving_embedding is not None
else None
),
"wrapper_fn": None,
},
"score": {
"url_matcher": lambda url: url.endswith("/score"),
"handler_getter": lambda: (
openai_serving_scores.create_score
if openai_serving_scores is not None
else None
),
"wrapper_fn": None,
},
"rerank": {
"url_matcher": lambda url: url.endswith("/rerank"),
"handler_getter": lambda: (
openai_serving_scores.do_rerank
if openai_serving_scores is not None
else None
),
"wrapper_fn": None,
},
"transcriptions": {
"url_matcher": lambda url: url == "/v1/audio/transcriptions",
"handler_getter": lambda: (
openai_serving_transcription.create_transcription
if openai_serving_transcription is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=False),
},
"translations": {
"url_matcher": lambda url: url == "/v1/audio/translations",
"handler_getter": lambda: (
openai_serving_translation.create_translation
if openai_serving_translation is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=True),
},
}
return endpoint_registry
def validate_run_batch_args(args):
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
reasoning_parser := args.structured_outputs_config.reasoning_parser
) and reasoning_parser not in valid_reasoning_parsers:
raise KeyError(
f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
async def run_batch(
engine_client: EngineClient,
args: Namespace,
) -> None:
endpoint_registry = await build_endpoint_registry(
engine_client=engine_client,
args=args,
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Use the last segment of the URL as the endpoint key.
# More advanced URL matching is done in url_matcher of endpoint_registry.
endpoint_key = request.url.split("/")[-1]
result = None
if endpoint_key in endpoint_registry:
endpoint_config = endpoint_registry[endpoint_key]
result = handle_endpoint_request(
request,
tracker,
url_matcher=endpoint_config["url_matcher"],
handler_getter=endpoint_config["handler_getter"],
wrapper_fn=endpoint_config["wrapper_fn"],
)
if result is not None:
response_futures.append(result)
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /v1/audio/transcriptions, /v1/audio/translations, /score, "
" /rerank. See vllm/entrypoints/openai/api_server.py "
"for supported score/rerank versions.",
)
)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.usage.usage_lib import UsageContext
validate_run_batch_args(args)
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
await run_batch(engine_client, args)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.host)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))

View File

@@ -0,0 +1,382 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import hashlib
import json
import secrets
import uuid
from argparse import Namespace
from collections.abc import Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
import pydantic
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from vllm import envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
from vllm.entrypoints.utils import sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.utils.gc_utils import freeze_gc_heap
logger = init_logger("vllm.entrypoints.openai.server_utils")
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
if the Authorization Bearer token exists and equals anyof "{api_key}".
Notes
-----
There are two cases in which authentication is skipped:
1. The HTTP method is OPTIONS.
2. The request path doesn't start with /v1 (e.g. /health).
"""
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app
self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
# scope["type"] can be "lifespan" or "startup" for example,
# in which case we don't need to do anything
return self.app(scope, receive, send)
root_path = scope.get("root_path", "")
url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope)
# Type narrow to satisfy mypy.
if url_path.startswith("/v1") and not self.verify_token(headers):
response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
return response(scope, receive, send)
return self.app(scope, receive, send)
class XRequestIdMiddleware:
"""
Middleware the set's the X-Request-Id header for each response
to a random uuid4 (hex) value if the header isn't already
present in the request, otherwise use the provided request id.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] not in ("http", "websocket"):
return self.app(scope, receive, send)
# Extract the request headers.
request_headers = Headers(scope=scope)
async def send_with_request_id(message: Message) -> None:
"""
Custom send function to mutate the response headers
and append X-Request-Id to it.
"""
if message["type"] == "http.response.start":
response_headers = MutableHeaders(raw=message["headers"])
request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex)
response_headers.append("X-Request-Id", request_id)
await send(message)
return self.app(scope, receive, send_with_request_id)
def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file:
return None
try:
with open(log_config_file) as f:
return json.load(f)
except Exception as e:
logger.warning(
"Failed to load log config from file %s: error %s", log_config_file, e
)
return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
return log_config
# If endpoints to filter are specified, create a config with the filter
if args.disable_access_log_for_endpoints:
from vllm.logging_utils import create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths = [
p.strip()
for p in args.disable_access_log_for_endpoints.split(",")
if p.strip()
]
return create_uvicorn_log_config(
excluded_paths=excluded_paths,
log_level=args.uvicorn_log_level,
)
return None
def _extract_content_from_chunk(chunk_data: dict) -> str:
"""Extract content from a streaming response chunk."""
try:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionStreamResponse,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionStreamResponse,
)
# Try using Completion types for type-safe parsing
if chunk_data.get("object") == "chat.completion.chunk":
chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
if chat_response.choices and chat_response.choices[0].delta.content:
return chat_response.choices[0].delta.content
elif chunk_data.get("object") == "text_completion":
completion_response = CompletionStreamResponse.model_validate(chunk_data)
if completion_response.choices and completion_response.choices[0].text:
return completion_response.choices[0].text
except pydantic.ValidationError:
# Fallback to manual parsing
if "choices" in chunk_data and chunk_data["choices"]:
choice = chunk_data["choices"][0]
if "delta" in choice and choice["delta"].get("content"):
return choice["delta"]["content"]
elif choice.get("text"):
return choice["text"]
return ""
class SSEDecoder:
"""Robust Server-Sent Events decoder for streaming responses."""
def __init__(self):
self.buffer = ""
self.content_buffer = []
def decode_chunk(self, chunk: bytes) -> list[dict]:
"""Decode a chunk of SSE data and return parsed events."""
import json
try:
chunk_str = chunk.decode("utf-8")
except UnicodeDecodeError:
# Skip malformed chunks
return []
self.buffer += chunk_str
events = []
# Process complete lines
while "\n" in self.buffer:
line, self.buffer = self.buffer.split("\n", 1)
line = line.rstrip("\r") # Handle CRLF
if line.startswith("data: "):
data_str = line[6:].strip()
if data_str == "[DONE]":
events.append({"type": "done"})
elif data_str:
try:
event_data = json.loads(data_str)
events.append({"type": "data", "data": event_data})
except json.JSONDecodeError:
# Skip malformed JSON
continue
return events
def extract_content(self, event_data: dict) -> str:
"""Extract content from event data."""
return _extract_content_from_chunk(event_data)
def add_content(self, content: str) -> None:
"""Add content to the buffer."""
if content:
self.content_buffer.append(content)
def get_complete_content(self) -> str:
"""Get the complete buffered content."""
return "".join(self.content_buffer)
def _log_streaming_response(response, response_body: list) -> None:
"""Log streaming response with robust SSE parsing."""
from starlette.concurrency import iterate_in_threadpool
sse_decoder = SSEDecoder()
chunk_count = 0
def buffered_iterator():
nonlocal chunk_count
for chunk in response_body:
chunk_count += 1
yield chunk
# Parse SSE events from chunk
events = sse_decoder.decode_chunk(chunk)
for event in events:
if event["type"] == "data":
content = sse_decoder.extract_content(event["data"])
sse_decoder.add_content(content)
elif event["type"] == "done":
# Log complete content when done
full_content = sse_decoder.get_complete_content()
if full_content:
# Truncate if too long
if len(full_content) > 2048:
full_content = full_content[:2048] + ""
"...[truncated]"
logger.info(
"response_body={streaming_complete: content=%r, chunks=%d}",
full_content,
chunk_count,
)
else:
logger.info(
"response_body={streaming_complete: no_content, chunks=%d}",
chunk_count,
)
return
response.body_iterator = iterate_in_threadpool(buffered_iterator())
logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
def _log_non_streaming_response(response_body: list) -> None:
"""Log non-streaming response."""
try:
decoded_body = response_body[0].decode()
logger.info("response_body={%s}", decoded_body)
except UnicodeDecodeError:
logger.info("response_body={<binary_data>}")
async def log_response(request: Request, call_next):
response = await call_next(request)
response_body = [section async for section in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
# Check if this is a streaming response by looking at content-type
content_type = response.headers.get("content-type", "")
is_streaming = content_type == "text/event-stream; charset=utf-8"
# Log response body based on type
if not response_body:
logger.info("response_body={<empty>}")
elif is_streaming:
_log_streaming_response(response, response_body)
else:
_log_non_streaming_response(response_body)
return response
async def http_exception_handler(_: Request, exc: HTTPException):
err = ErrorResponse(
error=ErrorInfo(
message=sanitize_message(exc.detail),
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
async def validation_exception_handler(_: Request, exc: RequestValidationError):
param = None
errors = exc.errors()
for error in errors:
if "ctx" in error and "error" in error["ctx"]:
ctx_error = error["ctx"]["error"]
if isinstance(ctx_error, VLLMValidationError):
param = ctx_error.parameter
break
exc_str = str(exc)
errors_str = str(errors)
if errors and errors_str and errors_str != exc_str:
message = f"{exc_str} {errors_str}"
else:
message = exc_str
err = ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST,
param=param,
)
)
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
_running_tasks: set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, FastAPI, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponseVariant,
TranslationRequest,
TranslationResponseVariant,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
logger = init_logger(__name__)
router = APIRouter()
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def translation(request: Request) -> OpenAIServingTranslation:
return request.app.state.openai_serving_translation
@router.post(
"/v1/audio/transcriptions",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
handler = transcription(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Transcriptions API"
)
audio_data = await request.file.read()
try:
generator = await handler.create_transcription(audio_data, request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranscriptionResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/audio/translations",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_translations(
request: Annotated[TranslationRequest, Form()], raw_request: Request
):
handler = translation(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Translations API"
)
audio_data = await request.file.read()
try:
generator = await handler.create_translation(audio_data, request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranslationResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
app.include_router(router)
def init_transcription_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)

View File

@@ -0,0 +1,545 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from http import HTTPStatus
from typing import Literal, TypeAlias
import torch
from fastapi import HTTPException, UploadFile
from pydantic import (
Field,
model_validator,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
OpenAIBaseModel,
UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.sampling_params import (
RequestOutputKind,
SamplingParams,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: str | None = None
stop_reason: int | str | None = None
class TranscriptionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
object: Literal["transcription.chunk"] = "transcription.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranscriptionResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
class TranscriptionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranscription
file: UploadFile
"""
The audio file object (not file name) to transcribe, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: str | None = None
"""ID of the model to use.
"""
language: str | None = None
"""The language of the input audio.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy and latency.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
## TODO (varun) : Support if set to 0, certain thresholds are met !!
timestamp_granularities: list[Literal["word", "segment"]] = Field(
alias="timestamp_granularities[]", default=[]
)
"""The timestamp granularities to populate for this transcription.
`response_format` must be set `verbose_json` to use timestamp granularities.
Either or both of these options are supported: `word`, or `segment`. Note:
There is no additional latency for segment timestamps, but generating word
timestamps incurs additional latency.
"""
stream: bool | None = False
"""When set, it will enable output to be streamed in a similar fashion
as the Chat Completion endpoint.
"""
# --8<-- [start:transcription-extra-params]
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:transcription-extra-params]
to_language: str | None = None
"""The language of the output audio we transcribe to.
Please note that this is not currently used by supported models at this
time, but it is a placeholder for future use, matching translation api.
"""
# --8<-- [start:transcription-sampling-params]
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
top_p: float | None = None
"""Enables nucleus (top-p) sampling, where tokens are selected from the
smallest possible set whose cumulative probability exceeds `p`.
"""
top_k: int | None = None
"""Limits sampling to the `k` most probable tokens at each step."""
min_p: float | None = None
"""Filters out tokens with a probability lower than `min_p`, ensuring a
minimum likelihood threshold during sampling.
"""
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
frequency_penalty: float | None = 0.0
"""The frequency penalty to use for sampling."""
repetition_penalty: float | None = None
"""The repetition penalty to use for sampling."""
presence_penalty: float | None = 0.0
"""The presence penalty to use for sampling."""
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
}
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
return SamplingParams.from_optional(
temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
presence_penalty=self.presence_penalty,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
extra_args=self.vllm_xargs,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def validate_transcription_request(cls, data):
if isinstance(data.get("file"), str):
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Expected 'file' to be a file-like object, not 'str'.",
)
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data
# Transcription response objects
class TranscriptionUsageAudio(OpenAIBaseModel):
type: Literal["duration"] = "duration"
seconds: int
class TranscriptionResponse(OpenAIBaseModel):
text: str
"""The transcribed text."""
usage: TranscriptionUsageAudio
class TranscriptionWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranscriptionSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float | None = None
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: list[int]
"""Array of token IDs for the text content."""
class TranscriptionResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The transcribed text."""
segments: list[TranscriptionSegment] | None = None
"""Segments of the transcribed text and their corresponding details."""
words: list[TranscriptionWord] | None = None
"""Extracted words and their corresponding timestamps."""
TranscriptionResponseVariant: TypeAlias = (
TranscriptionResponse | TranscriptionResponseVerbose
)
class TranslationResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: str | None = None
stop_reason: int | str | None = None
class TranslationStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
object: Literal["translation.chunk"] = "translation.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranslationResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
class TranslationRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranslation
file: UploadFile
"""
The audio file object (not file name) to translate, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: str | None = None
"""ID of the model to use.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
# --8<-- [end:translation-sampling-params]
# --8<-- [start:translation-extra-params]
language: str | None = None
"""The language of the input audio we translate from.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy.
"""
to_language: str | None = None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream: bool | None = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
}
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
return SamplingParams.from_optional(
temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data
# Translation response objects
class TranslationResponse(OpenAIBaseModel):
text: str
"""The translated text."""
class TranslationWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranslationSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float | None = None
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: list[int]
"""Array of token IDs for the text content."""
class TranslationResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The translated text."""
segments: list[TranslationSegment] | None = None
"""Segments of the translated text and their corresponding details."""
words: list[TranslationWord] | None = None
"""Extracted words and their corresponding timestamps."""
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose

View File

@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionStreamResponse,
TranslationRequest,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
logger = init_logger(__name__)
class OpenAIServingTranscription(OpenAISpeechToText):
"""Handles transcription requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_transcription(
self,
audio_data: bytes,
request: TranscriptionRequest,
raw_request: Request | None = None,
) -> (
TranscriptionResponse
| TranscriptionResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranscriptionResponseVerbose
if request.response_format == "verbose_json"
else TranscriptionResponse
),
stream_generator_method=self.transcription_stream_generator,
)
async def transcription_stream_generator(
self,
request: TranscriptionRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="transcription.chunk",
response_stream_choice_class=TranscriptionResponseStreamChoice,
stream_response_class=TranscriptionStreamResponse,
)
async for chunk in generator:
yield chunk
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(
self,
audio_data: bytes,
request: TranslationRequest,
raw_request: Request | None = None,
) -> (
TranslationResponse
| TranslationResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranslationResponseVerbose
if request.response_format == "verbose_json"
else TranslationResponse
),
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self,
request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
)
async for chunk in generator:
yield chunk

View File

@@ -0,0 +1,770 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
import zlib
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
ErrorResponse,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionSegment,
TranscriptionStreamResponse,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import (
SupportsTranscription,
supports_transcription,
)
from vllm.multimodal.audio import split_audio
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
)
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
T = TypeVar("T", bound=SpeechToTextResponse)
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
S = TypeVar("S", bound=SpeechToTextSegment)
ResponseType: TypeAlias = (
TranscriptionResponse
| TranslationResponse
| TranscriptionResponseVerbose
| TranslationResponseVerbose
)
logger = init_logger(__name__)
class OpenAISpeechToText(OpenAIServing):
"""Base class for speech-to-text operations like transcription and
translation."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
)
self.enable_force_include_usage = enable_force_include_usage
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.model_cls.supports_segment_timestamp:
self.tokenizer = cast(
PreTrainedTokenizerBase,
get_tokenizer(
tokenizer_name=self.model_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode,
),
)
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params,
)
# Warm up audio preprocessing to avoid first-request latency
self._warmup_audio_preprocessing()
# Warm up input processor with dummy audio
self._warmup_input_processor()
def _warmup_audio_preprocessing(self) -> None:
"""Warm up audio processing libraries to avoid first-request latency.
The first call to librosa functions (load, get_duration, mel-spectrogram)
triggers JIT compilation and library initialization which can take ~7s.
This method warms up these operations during server initialization.
"""
# Skip warmup if librosa is not installed (optional dependency)
if isinstance(librosa, PlaceholderModule):
return
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
if getattr(self.model_cls, "skip_warmup_audio_preprocessing", False):
return
try:
warmup_start = time.perf_counter()
logger.info("Warming up audio preprocessing libraries...")
# Create a minimal dummy audio (1 second of silence at target sample rate)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Warm up librosa.load by using librosa functions on the dummy data
# This initializes FFTW, numba JIT, and other audio processing libraries
_ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)
# Warm up mel-spectrogram computation with model-specific parameters
from vllm.transformers_utils.processor import cached_processor_from_config
processor = cached_processor_from_config(self.model_config)
feature_extractor = None
if hasattr(processor, "feature_extractor"):
feature_extractor = processor.feature_extractor
elif hasattr(processor, "audio_processor"):
# For models like GraniteSpeech that use audio_processor
audio_proc = processor.audio_processor
if hasattr(audio_proc, "feature_extractor"):
feature_extractor = audio_proc.feature_extractor
# If audio_processor doesn't have feature_extractor,
# skip mel-spectrogram warmup for these models
if feature_extractor is not None:
_ = librosa.feature.melspectrogram(
y=dummy_audio,
sr=self.asr_config.sample_rate,
n_mels=getattr(feature_extractor, "n_mels", 128),
n_fft=getattr(feature_extractor, "n_fft", 400),
hop_length=getattr(feature_extractor, "hop_length", 160),
)
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log exception and continue
logger.exception(
"Audio preprocessing warmup failed (non-fatal): %s. "
"First request may experience higher latency.",
)
def _warmup_input_processor(self) -> None:
"""Warm up input processor with dummy audio to avoid first-request latency.
The first call to renderer.render_cmpl() with multimodal audio
triggers multimodal processing initialization which can take ~2.5s.
This method processes a dummy audio request to warm up the pipeline.
"""
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
# Only warm up if model supports transcription methods
if not hasattr(self.model_cls, "get_generation_prompt"):
return
try:
warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...")
# Create minimal dummy audio (1 second of silence)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Use the same method that _preprocess_speech_to_text uses
# to create the prompt
dummy_prompt = self.model_cls.get_generation_prompt(
audio=dummy_audio,
stt_config=self.asr_config,
model_config=self.model_config,
language="en",
task_type=self.task_type,
request_prompt="",
to_language=None,
)
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
_ = self.renderer.render_cmpl([parsed_prompt])
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log warning and continue
logger.exception(
"Input processor warmup failed (non-fatal): %s. "
"First request may experience higher latency."
)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _detect_language(
self,
audio_chunk: np.ndarray,
request_id: str,
) -> str:
"""Auto-detect the spoken language from an audio chunk.
Delegates prompt construction and output parsing to the model class
via ``get_language_detection_prompt`` and
``parse_language_detection_output``.
"""
from vllm.sampling_params import SamplingParams
prompt = self.model_cls.get_language_detection_prompt(
audio_chunk,
self.asr_config,
)
allowed_token_ids = self.model_cls.get_language_token_ids(
self.tokenizer,
)
sampling_params = SamplingParams(
max_tokens=1,
temperature=0.0,
allowed_token_ids=allowed_token_ids,
)
result_generator = self.engine_client.generate(
prompt,
sampling_params,
request_id,
)
final_output: RequestOutput
async for final_output in result_generator:
if final_output.finished:
break
token_ids = list(final_output.outputs[0].token_ids)
lang = self.model_cls.parse_language_detection_output(
token_ids,
self.tokenizer,
)
logger.info("Auto-detected language: '%s'", lang)
return lang
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
request_id: str,
) -> tuple[list[ProcessorInputs], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = (
self.model_cls.validate_language(request.to_language)
if request.to_language
else None
)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise VLLMValidationError(
"Maximum file size exceeded",
parameter="audio_filesize_mb",
value=len(audio_data) / 1024**2,
)
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s
)
if not do_split_audio:
chunks = [y]
else:
assert self.asr_config.max_audio_clip_s is not None
assert self.asr_config.min_energy_split_window_size is not None
chunks = split_audio(
audio_data=y,
sample_rate=int(sr),
max_clip_duration_s=self.asr_config.max_audio_clip_s,
overlap_duration_s=self.asr_config.overlap_chunk_second,
min_energy_window_size=self.asr_config.min_energy_split_window_size,
)
if language is None and getattr(
self.model_cls, "supports_explicit_language_detection", False
):
# Auto-detect language from the first chunk.
language = await self._detect_language(
chunks[0], f"{request_id}-lang_detect"
)
request.language = language
parsed_prompts: list[DictPrompt] = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
parsed_prompt: DictPrompt
if request.response_format == "verbose_json":
parsed_prompt = parse_enc_dec_prompt(prompt)
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
else:
parsed_prompt = parse_model_prompt(self.model_config, prompt)
parsed_prompts.append(parsed_prompt)
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
return engine_prompts, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
raise VLLMValidationError(
"Expected decoder_prompt to contain text",
parameter="decoder_prompt",
value=type(dec_prompt).__name__,
)
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt
def _get_verbose_segments(
self,
tokens: tuple,
log_probs: FlatLogprobs | list[dict[int, Logprob]],
request: SpeechToTextRequest,
segment_class: type[SpeechToTextSegment],
start_time: float = 0,
) -> list[SpeechToTextSegment]:
"""
Convert tokens to verbose segments.
This method expects the model to produce
timestamps as tokens (similar to Whisper).
If the tokens do not include timestamp information,
the segments may not be generated correctly.
Note: No_speech_prob field is not supported
in this implementation and will be None. See docs for details.
"""
BASE_OFFSET = 0.02
init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
if tokens[-1] == self.tokenizer.eos_token_id:
tokens = tokens[:-1]
tokens_with_start = (init_token,) + tokens
segments: list[SpeechToTextSegment] = []
last_timestamp_start = 0
if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
avg_logprob = 0.0
for idx in range(1, len(tokens_with_start)):
# Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
# If the ordering is violated, this slicing may produce incorrect results.
token = tokens_with_start[idx]
if token >= init_token and tokens_with_start[idx - 1] >= init_token:
sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
start_timestamp = sliced_timestamp_tokens[0] - init_token
end_timestamp = sliced_timestamp_tokens[-1] - init_token
text = self.tokenizer.decode(sliced_timestamp_tokens[1:-1])
text_bytes = text.encode("utf-8")
casting_segment = cast(
SpeechToTextSegment,
segment_class(
id=len(segments),
seek=start_time,
start=start_time + BASE_OFFSET * start_timestamp,
end=start_time + BASE_OFFSET * end_timestamp,
temperature=request.temperature,
text=text,
# The compression ratio measures
# how compressible the generated text is.
# A higher ratio indicates more repetitive content,
# which is a strong sign of hallucination in outputs.
compression_ratio=len(text_bytes)
/ len(zlib.compress(text_bytes)),
tokens=sliced_timestamp_tokens[1:-1],
avg_logprob=avg_logprob / (idx - last_timestamp_start),
),
)
segments.append(casting_segment)
last_timestamp_start = idx
avg_logprob = 0
else:
avg_logprob += log_probs[idx - 1][token].logprob
return segments
async def _create_speech_to_text(
self,
audio_data: bytes,
request: SpeechToTextRequest,
raw_request: Request,
response_class: type[ResponseType],
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
"""Base method for speech-to-text operations like transcription and
translation."""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
if request.response_format not in ["text", "json", "verbose_json"]:
return self.create_error_response(
"Currently only support response_format: "
"`text`, `json` or `verbose_json`"
)
if (
request.response_format == "verbose_json"
and not self.model_cls.supports_segment_timestamp
):
return self.create_error_response(
f"Currently do not support verbose_json for {request.model}"
)
if request.response_format == "verbose_json" and request.stream:
return self.create_error_response(
"verbose_json format doesn't support streaming case"
)
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
request_id=request_id,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens,
0,
self.default_sampling_params,
)
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
)
list_result_generator.append(generator)
except ValueError as e:
return self.create_error_response(e)
if request.stream:
return stream_generator_method(
request, list_result_generator, request_id, request_metadata, duration_s
)
# Non-streaming response.
total_segments = []
text_parts = []
try:
assert list_result_generator is not None
segments_types: dict[str, type[SpeechToTextSegment]] = {
"transcribe": TranscriptionSegment,
"translate": TranslationSegment,
}
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
text = ""
chunk_size_in_s = self.asr_config.max_audio_clip_s
if chunk_size_in_s is None:
assert len(list_result_generator) == 1, (
"`max_audio_clip_s` is set to None, audio cannot be chunked"
)
for idx, result_generator in enumerate(list_result_generator):
start_time = (
float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
)
async for op in result_generator:
if request.response_format == "verbose_json":
assert op.outputs[0].logprobs
segments: list[SpeechToTextSegment] = (
self._get_verbose_segments(
tokens=tuple(op.outputs[0].token_ids),
segment_class=segment_class,
request=request,
start_time=start_time,
log_probs=op.outputs[0].logprobs,
)
)
total_segments.extend(segments)
text_parts.extend([seg.text for seg in segments])
else:
raw_text = op.outputs[0].text
text_parts.append(self.model_cls.post_process_output(raw_text))
text = "".join(text_parts)
if self.task_type == "transcribe":
final_response: ResponseType
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per openAI specs
"seconds": int(math.ceil(duration_s)),
}
if request.response_format != "verbose_json":
final_response = cast(
T, TranscriptionResponse(text=text, usage=usage)
)
else:
final_response = cast(
V,
TranscriptionResponseVerbose(
text=text,
language=request.language,
duration=str(duration_s),
segments=total_segments,
),
)
else:
# no usage in response for translation task
if request.response_format != "verbose_json":
final_response = cast(T, TranslationResponse(text=text))
else:
final_response = cast(
V,
TranslationResponseVerbose(
text=text,
language=request.language,
duration=str(duration_s),
segments=total_segments,
),
)
return final_response
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
async def _speech_to_text_stream_generator(
self,
request: SpeechToTextRequest,
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
response_stream_choice_class: type[TranscriptionResponseStreamChoice]
| type[TranslationResponseStreamChoice],
stream_response_class: type[TranscriptionStreamResponse]
| type[TranslationStreamResponse],
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
completion_tokens = 0
num_prompt_tokens = 0
include_usage = self.enable_force_include_usage or request.stream_include_usage
include_continuous_usage = (
request.stream_continuous_usage_stats
if include_usage and request.stream_continuous_usage_stats
else False
)
try:
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config, self.model_config
):
num_prompt_tokens += audio_tokens
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
# TODO: For models that output structured formats (e.g.,
# Qwen3-ASR with "language X<asr_text>" prefix), streaming
# would need buffering to strip the prefix properly since
# deltas may split the tag across chunks.
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = response_stream_choice_class(delta=delta_message)
else:
# Model is finished generating.
choice_data = response_stream_choice_class(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
if include_usage:
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
except Exception as e:
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"

View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
"This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.api_router' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.protocol' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.serving' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.serving'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.speech_to_text' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update "
"your imports. This backward-compatible alias will be removed in version "
"0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeVar
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
)
# Used internally
_ChatCompletionResponseChoiceT = TypeVar(
"_ChatCompletionResponseChoiceT",
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
)
def maybe_filter_parallel_tool_calls(
choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
) -> _ChatCompletionResponseChoiceT:
"""Filter to first tool call only when parallel_tool_calls is False."""
if request.parallel_tool_calls:
return choice
if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
choice.message.tool_calls = choice.message.tool_calls[:1]
elif (
isinstance(choice, ChatCompletionResponseStreamChoice)
and choice.delta.tool_calls
):
choice.delta.tool_calls = [
tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
]
return choice
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
)