update
This commit is contained in:
0
vllm/entrypoints/openai/__init__.py
Normal file
0
vllm/entrypoints/openai/__init__.py
Normal file
545
vllm/entrypoints/openai/api_server.py
Normal file
545
vllm/entrypoints/openai/api_server.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import uvloop
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.datastructures import State
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.server_utils import (
|
||||
get_uvicorn_log_config,
|
||||
http_exception_handler,
|
||||
lifespan,
|
||||
log_response,
|
||||
validation_exception_handler,
|
||||
)
|
||||
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
log_non_default_args,
|
||||
log_version_and_model,
|
||||
process_lora_modules,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.tracing import instrument
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.utils.system_utils import decorate_logs, set_ulimit
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger("vllm.entrypoints.openai.api_server")
|
||||
|
||||
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(
|
||||
args: Namespace,
|
||||
*,
|
||||
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
|
||||
disable_frontend_multiprocessing: bool | None = None,
|
||||
client_config: dict[str, Any] | None = None,
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
|
||||
# The executor is expected to be mp.
|
||||
# Pre-import heavy modules in the forkserver process
|
||||
logger.debug("Setup forkserver with pre-imports")
|
||||
multiprocessing.set_start_method("forkserver")
|
||||
multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
|
||||
forkserver.ensure_running()
|
||||
logger.debug("Forkserver setup complete!")
|
||||
|
||||
# Context manager to handle engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
if client_config:
|
||||
engine_args._api_process_count = client_config.get("client_count", 1)
|
||||
engine_args._api_process_rank = client_config.get("client_index", 0)
|
||||
|
||||
if disable_frontend_multiprocessing is None:
|
||||
disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args,
|
||||
usage_context=usage_context,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
client_config=client_config,
|
||||
) as engine:
|
||||
yield engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client_from_engine_args(
|
||||
engine_args: AsyncEngineArgs,
|
||||
*,
|
||||
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
client_config: dict[str, Any] | None = None,
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
"""
|
||||
Create EngineClient, either:
|
||||
- in-process using the AsyncLLMEngine Directly
|
||||
- multiprocess using AsyncLLMEngine RPC
|
||||
|
||||
Returns the Client or None if the creation failed.
|
||||
"""
|
||||
|
||||
# Create the EngineConfig (determines if we can use V1).
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
|
||||
if disable_frontend_multiprocessing:
|
||||
logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")
|
||||
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
async_llm: AsyncLLM | None = None
|
||||
|
||||
# Don't mutate the input client_config
|
||||
client_config = dict(client_config) if client_config else {}
|
||||
client_count = client_config.pop("client_count", 1)
|
||||
client_index = client_config.pop("client_index", 0)
|
||||
|
||||
try:
|
||||
async_llm = AsyncLLM.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
enable_log_requests=engine_args.enable_log_requests,
|
||||
aggregate_engine_logging=engine_args.aggregate_engine_logging,
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
client_addresses=client_config,
|
||||
client_count=client_count,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
# Don't keep the dummy data in memory
|
||||
assert async_llm is not None
|
||||
await async_llm.reset_mm_cache()
|
||||
|
||||
yield async_llm
|
||||
finally:
|
||||
if async_llm:
|
||||
async_llm.shutdown()
|
||||
|
||||
|
||||
def build_app(
|
||||
args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None
|
||||
) -> FastAPI:
|
||||
if supported_tasks is None:
|
||||
warnings.warn(
|
||||
"The 'supported_tasks' parameter was not provided to "
|
||||
"build_app and will be required in a future version. "
|
||||
"Defaulting to ('generate',).",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
supported_tasks = _FALLBACK_SUPPORTED_TASKS
|
||||
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(
|
||||
openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
|
||||
)
|
||||
elif args.enable_offline_docs:
|
||||
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.state.args = args
|
||||
|
||||
from vllm.entrypoints.serve import register_vllm_serve_api_routers
|
||||
|
||||
register_vllm_serve_api_routers(app)
|
||||
|
||||
from vllm.entrypoints.openai.models.api_router import (
|
||||
attach_router as register_models_api_router,
|
||||
)
|
||||
|
||||
register_models_api_router(app)
|
||||
|
||||
from vllm.entrypoints.sagemaker.api_router import (
|
||||
attach_router as register_sagemaker_api_router,
|
||||
)
|
||||
|
||||
register_sagemaker_api_router(app, supported_tasks)
|
||||
|
||||
if "generate" in supported_tasks:
|
||||
from vllm.entrypoints.openai.generate.api_router import (
|
||||
register_generate_api_routers,
|
||||
)
|
||||
|
||||
register_generate_api_routers(app)
|
||||
|
||||
from vllm.entrypoints.serve.disagg.api_router import (
|
||||
attach_router as attach_disagg_router,
|
||||
)
|
||||
|
||||
attach_disagg_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.rlhf.api_router import (
|
||||
attach_router as attach_rlhf_router,
|
||||
)
|
||||
|
||||
attach_rlhf_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.elastic_ep.api_router import (
|
||||
attach_router as elastic_ep_attach_router,
|
||||
)
|
||||
|
||||
elastic_ep_attach_router(app)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import (
|
||||
attach_router as register_speech_to_text_api_router,
|
||||
)
|
||||
|
||||
register_speech_to_text_api_router(app)
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.openai.realtime.api_router import (
|
||||
attach_router as register_realtime_api_router,
|
||||
)
|
||||
|
||||
register_realtime_api_router(app)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app, supported_tasks)
|
||||
|
||||
app.root_path = args.root_path
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.allowed_origins,
|
||||
allow_credentials=args.allow_credentials,
|
||||
allow_methods=args.allowed_methods,
|
||||
allow_headers=args.allowed_headers,
|
||||
)
|
||||
|
||||
app.exception_handler(HTTPException)(http_exception_handler)
|
||||
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
||||
|
||||
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
|
||||
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
|
||||
from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, tokens=tokens)
|
||||
|
||||
if args.enable_request_id_headers:
|
||||
from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware
|
||||
|
||||
app.add_middleware(XRequestIdMiddleware)
|
||||
|
||||
# Add scaling middleware to check for scaling state
|
||||
app.add_middleware(ScalingMiddleware)
|
||||
|
||||
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
|
||||
logger.warning(
|
||||
"CAUTION: Enabling log response in the API Server. "
|
||||
"This can include sensitive information and should be "
|
||||
"avoided in production."
|
||||
)
|
||||
app.middleware("http")(log_response)
|
||||
|
||||
for middleware in args.middleware:
|
||||
module_path, object_name = middleware.rsplit(".", 1)
|
||||
imported = getattr(importlib.import_module(module_path), object_name)
|
||||
if inspect.isclass(imported):
|
||||
app.add_middleware(imported) # type: ignore[arg-type]
|
||||
elif inspect.iscoroutinefunction(imported):
|
||||
app.middleware("http")(imported)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid middleware {middleware}. Must be a function or a class."
|
||||
)
|
||||
|
||||
app = sagemaker_standards_bootstrap(app)
|
||||
return app
|
||||
|
||||
|
||||
async def init_app_state(
|
||||
engine_client: EngineClient,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
supported_tasks: tuple["SupportedTask", ...] | None = None,
|
||||
) -> None:
|
||||
vllm_config = engine_client.vllm_config
|
||||
if supported_tasks is None:
|
||||
warnings.warn(
|
||||
"The 'supported_tasks' parameter was not provided to "
|
||||
"init_app_state and will be required in a future version. "
|
||||
"Please pass 'supported_tasks' explicitly.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
supported_tasks = _FALLBACK_SUPPORTED_TASKS
|
||||
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
state.args = args
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
# Merge default_mm_loras into the static lora_modules
|
||||
default_mm_loras = (
|
||||
vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None
|
||||
else {}
|
||||
)
|
||||
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
)
|
||||
await state.openai_serving_models.init_static_loras()
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
|
||||
if "generate" in supported_tasks:
|
||||
from vllm.entrypoints.openai.generate.api_router import init_generate_state
|
||||
|
||||
await init_generate_state(
|
||||
engine_client, state, args, request_logger, supported_tasks
|
||||
)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import (
|
||||
init_transcription_state,
|
||||
)
|
||||
|
||||
init_transcription_state(
|
||||
engine_client, state, args, request_logger, supported_tasks
|
||||
)
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
|
||||
|
||||
init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling import init_pooling_state
|
||||
|
||||
init_pooling_state(engine_client, state, args, request_logger, supported_tasks)
|
||||
|
||||
state.enable_server_load_tracking = args.enable_server_load_tracking
|
||||
state.server_load_metrics = 0
|
||||
|
||||
|
||||
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
|
||||
family = socket.AF_INET
|
||||
if is_valid_ipv6_address(addr[0]):
|
||||
family = socket.AF_INET6
|
||||
|
||||
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
sock.bind(addr)
|
||||
|
||||
return sock
|
||||
|
||||
|
||||
def create_server_unix_socket(path: str) -> socket.socket:
|
||||
sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
|
||||
sock.bind(path)
|
||||
return sock
|
||||
|
||||
|
||||
def validate_api_server_args(args):
|
||||
valid_tool_parses = ToolParserManager.list_registered()
|
||||
if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parses:
|
||||
raise KeyError(
|
||||
f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valid_tool_parses)} }})"
|
||||
)
|
||||
|
||||
valid_reasoning_parsers = ReasoningParserManager.list_registered()
|
||||
if (
|
||||
reasoning_parser := args.structured_outputs_config.reasoning_parser
|
||||
) and reasoning_parser not in valid_reasoning_parsers:
|
||||
raise KeyError(
|
||||
f"invalid reasoning parser: {reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
|
||||
)
|
||||
|
||||
|
||||
@instrument(span_name="API server setup")
|
||||
def setup_server(args):
|
||||
"""Validate API server args, set up signal handler, create socket
|
||||
ready to serve."""
|
||||
|
||||
log_version_and_model(logger, VLLM_VERSION, args.model)
|
||||
log_non_default_args(args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
|
||||
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
|
||||
|
||||
validate_api_server_args(args)
|
||||
|
||||
# workaround to make sure that we bind the port before the engine is set up.
|
||||
# This avoids race conditions with ray.
|
||||
# see https://github.com/vllm-project/vllm/issues/8204
|
||||
if args.uds:
|
||||
sock = create_server_unix_socket(args.uds)
|
||||
else:
|
||||
sock_addr = (args.host or "", args.port)
|
||||
sock = create_server_socket(sock_addr)
|
||||
|
||||
# workaround to avoid footguns where uvicorn drops requests with too
|
||||
# many concurrent requests active
|
||||
set_ulimit()
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
if args.uds:
|
||||
listen_address = f"unix:{args.uds}"
|
||||
else:
|
||||
addr, port = sock_addr
|
||||
is_ssl = args.ssl_keyfile and args.ssl_certfile
|
||||
host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
|
||||
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
|
||||
return listen_address, sock
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
"""Run a single-worker API server."""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr.
|
||||
decorate_logs("APIServer")
|
||||
|
||||
listen_address, sock = setup_server(args)
|
||||
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
|
||||
|
||||
|
||||
async def run_server_worker(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Run a single API server worker."""
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
|
||||
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
|
||||
|
||||
# Get uvicorn log config (from file or with endpoint filter)
|
||||
log_config = get_uvicorn_log_config(args)
|
||||
if log_config is not None:
|
||||
uvicorn_kwargs["log_config"] = log_config
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
client_config=client_config,
|
||||
) as engine_client:
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
app = build_app(args, supported_tasks)
|
||||
await init_app_state(engine_client, app.state, args, supported_tasks)
|
||||
|
||||
logger.info(
|
||||
"Starting vLLM API server %d on %s",
|
||||
engine_client.vllm_config.parallel_config._api_process_rank,
|
||||
listen_address,
|
||||
)
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=sock,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.uvicorn_log_level,
|
||||
# NOTE: When the 'disable_uvicorn_access_log' value is True,
|
||||
# no access log will be output.
|
||||
access_log=not args.disable_uvicorn_access_log,
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
ssl_ciphers=args.ssl_ciphers,
|
||||
h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
|
||||
h11_max_header_count=args.h11_max_header_count,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
# NB: Await server shutdown only after the backend context is exited
|
||||
try:
|
||||
await shutdown_task
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE(simon):
|
||||
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
|
||||
# entrypoints.
|
||||
cli_env_setup()
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server."
|
||||
)
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args()
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
uvloop.run(run_server(args))
|
||||
2
vllm/entrypoints/openai/chat_completion/__init__.py
Normal file
2
vllm/entrypoints/openai/chat_completion/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
108
vllm/entrypoints/openai/chat_completion/api_router.py
Normal file
108
vllm/entrypoints/openai/chat_completion/api_router.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.orca_metrics import metrics_header
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat | None:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
|
||||
metrics_header_format = raw_request.headers.get(
|
||||
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
|
||||
)
|
||||
handler = chat(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Chat Completions API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_chat_completion(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(),
|
||||
headers=metrics_header(metrics_header_format),
|
||||
)
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/chat/completions/render",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
response_model=list,
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
|
||||
"""Render chat completion request and return conversation and engine
|
||||
prompts without generating."""
|
||||
handler = chat(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Chat Completions API"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler.render_chat_request(request)
|
||||
except Exception as e:
|
||||
result = handler.create_error_response(e)
|
||||
|
||||
if isinstance(result, ErrorResponse):
|
||||
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
733
vllm/entrypoints/openai/chat_completion/protocol.py
Normal file
733
vllm/entrypoints/openai/chat_completion/protocol.py
Normal file
@@ -0,0 +1,733 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import json
|
||||
import time
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
|
||||
import torch
|
||||
from openai.types.chat.chat_completion_audio import (
|
||||
ChatCompletionAudio as OpenAIChatCompletionAudio,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.utils import replace
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
AnyResponseFormat,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
FunctionDefinition,
|
||||
LegacyStructuralTagResponseFormat,
|
||||
OpenAIBaseModel,
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
ToolCall,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
refusal: str | None = None
|
||||
annotations: OpenAIAnnotation | None = None
|
||||
audio: OpenAIChatCompletionAudio | None = None
|
||||
function_call: FunctionCall | None = None
|
||||
tool_calls: list[ToolCall] = Field(default_factory=list)
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionLogProb(OpenAIBaseModel):
|
||||
token: str
|
||||
logprob: float = -9999.0
|
||||
bytes: list[int] | None = None
|
||||
|
||||
|
||||
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
|
||||
# Workaround: redefine fields name cache so that it's not
|
||||
# shared with the super class.
|
||||
field_names: ClassVar[set[str] | None] = None
|
||||
top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProbs(OpenAIBaseModel):
|
||||
content: list[ChatCompletionLogProbsContent] | None = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: ChatCompletionLogProbs | None = None
|
||||
# per OpenAI spec this is the default
|
||||
finish_reason: str | None = "stop"
|
||||
# not part of the OpenAI spec but included in vLLM for legacy reasons
|
||||
stop_reason: int | str | None = None
|
||||
# not part of the OpenAI spec but is useful for tracing the tokens
|
||||
# in agent scenarios
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[ChatCompletionResponseChoice]
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
|
||||
system_fingerprint: str | None = None
|
||||
usage: UsageInfo
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
|
||||
prompt_token_ids: list[int] | None = None
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None, description="KVTransfer parameters."
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: ChatCompletionLogProbs | None = None
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
# not part of the OpenAI spec but for tracing the tokens
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[ChatCompletionResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
# not part of the OpenAI spec but for tracing the tokens
|
||||
prompt_token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class ChatCompletionToolsParam(OpenAIBaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionDefinition
|
||||
|
||||
|
||||
class ChatCompletionNamedFunction(OpenAIBaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
|
||||
function: ChatCompletionNamedFunction
|
||||
type: Literal["function"] = "function"
|
||||
|
||||
|
||||
class ChatCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
model: str | None = None
|
||||
frequency_penalty: float | None = 0.0
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: bool | None = False
|
||||
top_logprobs: int | None = 0
|
||||
max_tokens: int | None = Field(
|
||||
default=None,
|
||||
deprecated="max_tokens is deprecated in favor of "
|
||||
"the max_completion_tokens field",
|
||||
)
|
||||
max_completion_tokens: int | None = None
|
||||
n: int | None = 1
|
||||
presence_penalty: float | None = 0.0
|
||||
response_format: AnyResponseFormat | None = None
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: str | list[str] | None = []
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
tools: list[ChatCompletionToolsParam] | None = None
|
||||
tool_choice: (
|
||||
Literal["none"]
|
||||
| Literal["auto"]
|
||||
| Literal["required"]
|
||||
| ChatCompletionNamedToolChoiceParam
|
||||
| None
|
||||
) = "none"
|
||||
reasoning_effort: Literal["low", "medium", "high"] | None = None
|
||||
include_reasoning: bool = True
|
||||
parallel_tool_calls: bool | None = True
|
||||
|
||||
# NOTE this will be ignored by vLLM
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-completion-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
top_k: int | None = None
|
||||
min_p: float | None = None
|
||||
repetition_penalty: float | None = None
|
||||
length_penalty: float = 1.0
|
||||
stop_token_ids: list[int] | None = []
|
||||
include_stop_str_in_output: bool = False
|
||||
ignore_eos: bool = False
|
||||
min_tokens: int = 0
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
|
||||
None
|
||||
)
|
||||
prompt_logprobs: int | None = None
|
||||
allowed_token_ids: list[int] | None = None
|
||||
bad_words: list[str] = Field(default_factory=list)
|
||||
# --8<-- [end:chat-completion-sampling-params]
|
||||
|
||||
# --8<-- [start:chat-completion-extra-params]
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the new message will be prepended with the last message "
|
||||
"if they belong to the same role."
|
||||
),
|
||||
)
|
||||
add_generation_prompt: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
continue_final_message: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If this is set, the chat will be formatted so that the final "
|
||||
"message in the chat is open-ended, without any EOS tokens. The "
|
||||
"model will continue this message rather than starting a new one. "
|
||||
'This allows you to "prefill" part of the model\'s response for it. '
|
||||
"Cannot be used at the same time as `add_generation_prompt`."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
documents: list[dict[str, str]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A list of dicts representing documents that will be accessible to "
|
||||
"the model if it is performing RAG (retrieval-augmented generation)."
|
||||
" If the template does not support RAG, this argument will have no "
|
||||
"effect. We recommend that each document should be a dict containing "
|
||||
'"title" and "text" keys.'
|
||||
),
|
||||
)
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
structured_outputs: StructuredOutputsParams | None = Field(
|
||||
default=None,
|
||||
description="Additional kwargs for structured outputs",
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
|
||||
return_tokens_as_token_ids: bool | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified with 'logprobs', tokens are represented "
|
||||
" as strings of the form 'token_id:{token_id}' so that tokens "
|
||||
"that are not JSON-encodable can be identified."
|
||||
),
|
||||
)
|
||||
return_token_ids: bool | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the result will include token IDs alongside the "
|
||||
"generated text. In streaming mode, prompt_token_ids is included "
|
||||
"only in the first chunk, and token_ids contains the delta tokens "
|
||||
"for each chunk. This is useful for debugging or when you "
|
||||
"need to map generated text back to input tokens."
|
||||
),
|
||||
)
|
||||
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with (list of) string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
|
||||
# --8<-- [end:chat-completion-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template or default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
continue_final_message=self.continue_final_message,
|
||||
documents=self.documents,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
if self.max_completion_tokens is not None:
|
||||
max_output_tokens: int | None = self.max_completion_tokens
|
||||
max_output_tokens_param = "max_completion_tokens"
|
||||
else:
|
||||
max_output_tokens = self.max_tokens
|
||||
max_output_tokens_param = "max_tokens"
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=max_output_tokens or 0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
needs_detokenization=bool(self.echo and not self.return_token_ids),
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param=max_output_tokens_param,
|
||||
)
|
||||
|
||||
# Default sampling parameters for chat completion requests
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 0,
|
||||
"min_p": 0.0,
|
||||
}
|
||||
|
||||
def to_beam_search_params(
|
||||
self, max_tokens: int, default_sampling_params: dict
|
||||
) -> BeamSearchParams:
|
||||
n = self.n if self.n is not None else 1
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
max_tokens: int,
|
||||
default_sampling_params: dict,
|
||||
) -> SamplingParams:
|
||||
# Default parameters
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
"repetition_penalty",
|
||||
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
|
||||
)
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
||||
)
|
||||
if (top_k := self.top_k) is None:
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||
)
|
||||
if (min_p := self.min_p) is None:
|
||||
min_p = default_sampling_params.get(
|
||||
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
|
||||
)
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.top_logprobs
|
||||
|
||||
response_format = self.response_format
|
||||
if response_format is not None:
|
||||
structured_outputs_kwargs = dict[str, Any]()
|
||||
|
||||
# Set structured output params for response format
|
||||
if response_format.type == "json_object":
|
||||
structured_outputs_kwargs["json_object"] = True
|
||||
elif response_format.type == "json_schema":
|
||||
json_schema = response_format.json_schema
|
||||
assert json_schema is not None
|
||||
structured_outputs_kwargs["json"] = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuralTagResponseFormat,
|
||||
),
|
||||
)
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
|
||||
|
||||
# If structured outputs wasn't already enabled,
|
||||
# we must enable it for these features to work
|
||||
if len(structured_outputs_kwargs) > 0:
|
||||
self.structured_outputs = (
|
||||
StructuredOutputsParams(**structured_outputs_kwargs)
|
||||
if self.structured_outputs is None
|
||||
else replace(self.structured_outputs, **structured_outputs_kwargs)
|
||||
)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
# Pass in kv_transfer_params via extra_args
|
||||
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
seed=self.seed,
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
logprobs=self.top_logprobs if self.logprobs else None,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=self.min_tokens,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
structured_outputs=self.structured_outputs,
|
||||
logit_bias=self.logit_bias,
|
||||
bad_words=self.bad_words,
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=extra_args or None,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter="stream_options",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.",
|
||||
parameter="prompt_logprobs",
|
||||
)
|
||||
|
||||
if prompt_logprobs < 0 and prompt_logprobs != -1:
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` must be a positive value or -1.",
|
||||
parameter="prompt_logprobs",
|
||||
value=prompt_logprobs,
|
||||
)
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0 and top_logprobs != -1:
|
||||
raise VLLMValidationError(
|
||||
"`top_logprobs` must be a positive value or -1.",
|
||||
parameter="top_logprobs",
|
||||
value=top_logprobs,
|
||||
)
|
||||
|
||||
if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
|
||||
raise VLLMValidationError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true.",
|
||||
parameter="top_logprobs",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_structured_outputs_count(cls, data):
|
||||
if isinstance(data, ValueError):
|
||||
raise data
|
||||
|
||||
if data.get("structured_outputs", None) is None:
|
||||
return data
|
||||
|
||||
structured_outputs_kwargs = data["structured_outputs"]
|
||||
# structured_outputs may arrive as a dict (from JSON/raw kwargs) or
|
||||
# as a StructuredOutputsParams dataclass instance.
|
||||
is_dataclass = isinstance(structured_outputs_kwargs, StructuredOutputsParams)
|
||||
count = sum(
|
||||
(
|
||||
getattr(structured_outputs_kwargs, k, None)
|
||||
if is_dataclass
|
||||
else structured_outputs_kwargs.get(k)
|
||||
)
|
||||
is not None
|
||||
for k in ("json", "regex", "choice")
|
||||
)
|
||||
# you can only use one kind of constraints for structured outputs
|
||||
if count > 1:
|
||||
raise ValueError(
|
||||
"You can only use one kind of constraints for structured "
|
||||
"outputs ('json', 'regex' or 'choice')."
|
||||
)
|
||||
# you can only either use structured outputs or tools, not both
|
||||
if count > 1 and data.get("tool_choice", "none") not in (
|
||||
"none",
|
||||
"auto",
|
||||
"required",
|
||||
):
|
||||
raise ValueError(
|
||||
"You can only either use constraints for structured outputs "
|
||||
"or tools, not both."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tool_usage(cls, data):
|
||||
# if "tool_choice" is not specified but tools are provided,
|
||||
# default to "auto" tool_choice
|
||||
if "tool_choice" not in data and data.get("tools"):
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
# if "tool_choice" is "none" -- no validation is needed for tools
|
||||
if "tool_choice" in data and data["tool_choice"] == "none":
|
||||
return data
|
||||
|
||||
# if "tool_choice" is specified -- validation
|
||||
if "tool_choice" in data and data["tool_choice"] is not None:
|
||||
# ensure that if "tool choice" is specified, tools are present
|
||||
if "tools" not in data or data["tools"] is None:
|
||||
raise ValueError("When using `tool_choice`, `tools` must be set.")
|
||||
|
||||
# make sure that tool choice is either a named tool
|
||||
# OR that it's set to "auto" or "required"
|
||||
if data["tool_choice"] not in ["auto", "required"] and not isinstance(
|
||||
data["tool_choice"], dict
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid value for `tool_choice`: {data['tool_choice']}! "
|
||||
'Only named tools, "none", "auto" or "required" '
|
||||
"are supported."
|
||||
)
|
||||
|
||||
# if tool_choice is "required" but the "tools" list is empty,
|
||||
# override the data to behave like "none" to align with
|
||||
# OpenAI’s behavior.
|
||||
if (
|
||||
data["tool_choice"] == "required"
|
||||
and isinstance(data["tools"], list)
|
||||
and len(data["tools"]) == 0
|
||||
):
|
||||
data["tool_choice"] = "none"
|
||||
del data["tools"]
|
||||
return data
|
||||
|
||||
# ensure that if "tool_choice" is specified as an object,
|
||||
# it matches a valid tool
|
||||
correct_usage_message = (
|
||||
'Correct usage: `{"type": "function",'
|
||||
' "function": {"name": "my_function"}}`'
|
||||
)
|
||||
if isinstance(data["tool_choice"], dict):
|
||||
valid_tool = False
|
||||
function = data["tool_choice"].get("function")
|
||||
if not isinstance(function, dict):
|
||||
raise ValueError(
|
||||
f"Invalid value for `function`: `{function}` in "
|
||||
f"`tool_choice`! {correct_usage_message}"
|
||||
)
|
||||
if "name" not in function:
|
||||
raise ValueError(
|
||||
f"Expected field `name` in `function` in "
|
||||
f"`tool_choice`! {correct_usage_message}"
|
||||
)
|
||||
function_name = function["name"]
|
||||
if not isinstance(function_name, str) or len(function_name) == 0:
|
||||
raise ValueError(
|
||||
f"Invalid `name` in `function`: `{function_name}`"
|
||||
f" in `tool_choice`! {correct_usage_message}"
|
||||
)
|
||||
for tool in data["tools"]:
|
||||
if tool["function"]["name"] == function_name:
|
||||
valid_tool = True
|
||||
break
|
||||
if not valid_tool:
|
||||
raise ValueError(
|
||||
"The tool specified in `tool_choice` does not match any"
|
||||
" of the specified `tools`"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||
raise ValueError(
|
||||
"Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_cache_salt_support(cls, data):
|
||||
if data.get("cache_salt") is not None and (
|
||||
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_system_message_content_type(cls, data):
|
||||
"""Warn if system messages contain non-text content.
|
||||
|
||||
According to OpenAI API spec, system messages can only be of type
|
||||
'text'. We log a warning instead of rejecting to avoid breaking
|
||||
users who intentionally send multimodal system messages.
|
||||
See: https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-system_message
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
messages = data.get("messages", [])
|
||||
for msg in messages:
|
||||
# Check if this is a system message
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
|
||||
# If content is a list (multimodal format)
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = part.get("type")
|
||||
# Infer type when 'type' field is not explicit
|
||||
if part_type is None:
|
||||
if "image_url" in part or "image_pil" in part:
|
||||
part_type = "image_url"
|
||||
elif "image_embeds" in part:
|
||||
part_type = "image_embeds"
|
||||
elif "audio_url" in part:
|
||||
part_type = "audio_url"
|
||||
elif "input_audio" in part:
|
||||
part_type = "input_audio"
|
||||
elif "audio_embeds" in part:
|
||||
part_type = "audio_embeds"
|
||||
elif "video_url" in part:
|
||||
part_type = "video_url"
|
||||
|
||||
# Warn about non-text content in system messages
|
||||
if part_type and part_type != "text":
|
||||
logger.warning_once(
|
||||
"System messages should only contain text "
|
||||
"content according to the OpenAI API spec. "
|
||||
"Found content type: '%s'.",
|
||||
part_type,
|
||||
)
|
||||
|
||||
return data
|
||||
1981
vllm/entrypoints/openai/chat_completion/serving.py
Normal file
1981
vllm/entrypoints/openai/chat_completion/serving.py
Normal file
File diff suppressed because it is too large
Load Diff
171
vllm/entrypoints/openai/chat_completion/stream_harmony.py
Normal file
171
vllm/entrypoints/openai/chat_completion/stream_harmony.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Harmony-specific streaming delta extraction for chat completions.
|
||||
|
||||
This module handles the extraction of DeltaMessage objects from
|
||||
harmony parser state during streaming chat completions.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from openai_harmony import StreamableParser
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
|
||||
|
||||
class TokenState(NamedTuple):
|
||||
channel: str | None
|
||||
recipient: str | None
|
||||
text: str
|
||||
|
||||
|
||||
def extract_harmony_streaming_delta(
|
||||
harmony_parser: StreamableParser,
|
||||
token_states: list[TokenState],
|
||||
prev_recipient: str | None,
|
||||
include_reasoning: bool,
|
||||
) -> tuple[DeltaMessage | None, bool]:
|
||||
"""
|
||||
Extract a DeltaMessage from harmony parser state during streaming.
|
||||
|
||||
Args:
|
||||
harmony_parser: The StreamableParser instance tracking parse state
|
||||
token_states: List of TokenState tuples for each token
|
||||
prev_recipient: Previous recipient for detecting tool call transitions
|
||||
include_reasoning: Whether to include reasoning content
|
||||
|
||||
Returns:
|
||||
A tuple of (DeltaMessage or None, tools_streamed_flag)
|
||||
"""
|
||||
|
||||
if not token_states:
|
||||
return None, False
|
||||
|
||||
tools_streamed = False
|
||||
|
||||
# Group consecutive tokens with same channel/recipient
|
||||
groups: list[TokenState] = []
|
||||
|
||||
current_channel = token_states[0].channel
|
||||
current_recipient = token_states[0].recipient
|
||||
current_text = token_states[0].text
|
||||
|
||||
for i in range(1, len(token_states)):
|
||||
state = token_states[i]
|
||||
if state.channel == current_channel and state.recipient == current_recipient:
|
||||
current_text += state.text
|
||||
else:
|
||||
groups.append(TokenState(current_channel, current_recipient, current_text))
|
||||
current_channel = state.channel
|
||||
current_recipient = state.recipient
|
||||
current_text = state.text
|
||||
|
||||
groups.append(TokenState(current_channel, current_recipient, current_text))
|
||||
|
||||
# Process each group and create delta messages
|
||||
delta_message = None
|
||||
combined_content = ""
|
||||
combined_reasoning = ""
|
||||
tool_messages = []
|
||||
content_encountered = False
|
||||
|
||||
# Calculate base_index once before the loop
|
||||
# This counts completed tool calls in messages
|
||||
base_index = 0
|
||||
for msg in harmony_parser.messages:
|
||||
if (
|
||||
(msg.channel == "commentary" or msg.channel == "analysis")
|
||||
and msg.recipient
|
||||
and msg.recipient.startswith("functions.")
|
||||
):
|
||||
base_index += 1
|
||||
|
||||
# If there's an ongoing tool call from previous chunk,
|
||||
# the next new tool call starts at base_index + 1
|
||||
if prev_recipient and prev_recipient.startswith("functions."):
|
||||
next_tool_index = base_index + 1
|
||||
# Ongoing call is at base_index
|
||||
ongoing_tool_index = base_index
|
||||
else:
|
||||
# No ongoing call, next new call is at base_index
|
||||
next_tool_index = base_index
|
||||
ongoing_tool_index = None
|
||||
|
||||
for group in groups:
|
||||
if group.channel == "final":
|
||||
combined_content += group.text
|
||||
content_encountered = True
|
||||
elif (
|
||||
(group.channel == "commentary" or group.channel == "analysis")
|
||||
and group.recipient
|
||||
and group.recipient.startswith("functions.")
|
||||
):
|
||||
opened_new_call = False
|
||||
if prev_recipient != group.recipient:
|
||||
# New tool call - emit the opening message
|
||||
tool_name = group.recipient.split("functions.", 1)[1]
|
||||
tool_messages.append(
|
||||
DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name,
|
||||
arguments="",
|
||||
),
|
||||
index=next_tool_index,
|
||||
)
|
||||
)
|
||||
opened_new_call = True
|
||||
prev_recipient = group.recipient
|
||||
# Increment for subsequent new tool calls
|
||||
next_tool_index += 1
|
||||
|
||||
if group.text:
|
||||
# Stream arguments for the ongoing tool call
|
||||
if opened_new_call:
|
||||
# Just opened in this group
|
||||
tool_call_index = next_tool_index - 1
|
||||
else:
|
||||
# Continuing from previous chunk
|
||||
# If ongoing_tool_index is None here, it means
|
||||
# we're continuing a call but prev_recipient
|
||||
# wasn't a function. Use base_index.
|
||||
tool_call_index = (
|
||||
ongoing_tool_index
|
||||
if ongoing_tool_index is not None
|
||||
else base_index
|
||||
)
|
||||
tool_messages.append(
|
||||
DeltaToolCall(
|
||||
index=tool_call_index,
|
||||
function=DeltaFunctionCall(arguments=group.text),
|
||||
)
|
||||
)
|
||||
elif group.channel == "commentary" and group.recipient is None:
|
||||
# Tool call preambles meant to be shown to the user
|
||||
combined_content += group.text
|
||||
content_encountered = True
|
||||
elif group.channel == "analysis" and include_reasoning:
|
||||
combined_reasoning += group.text
|
||||
|
||||
# Combine all non-empty fields into a single message
|
||||
if content_encountered or combined_reasoning or tool_messages:
|
||||
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
|
||||
if content_encountered:
|
||||
delta_kwargs["content"] = combined_content
|
||||
if combined_reasoning:
|
||||
delta_kwargs["reasoning"] = combined_reasoning
|
||||
if tool_messages:
|
||||
delta_kwargs["tool_calls"] = tool_messages
|
||||
tools_streamed = True
|
||||
delta_message = DeltaMessage(**delta_kwargs)
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
return delta_message, tools_streamed
|
||||
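For orientation, here is a minimal usage sketch of the helper above. The token states are hypothetical, and the stub parser is only a stand-in for StreamableParser: the function reads nothing but its `.messages` attribute, so a simple namespace is enough for illustration.

# Illustrative sketch only (hypothetical token states, not vLLM test code).
from types import SimpleNamespace

stub_parser = SimpleNamespace(messages=[])  # no completed tool calls yet
states = [
    TokenState(channel="analysis", recipient=None, text="thinking... "),
    TokenState(channel="final", recipient=None, text="Hello"),
    TokenState(channel="final", recipient=None, text=" world"),
]
delta, tools_streamed = extract_harmony_streaming_delta(
    harmony_parser=stub_parser,  # only .messages is accessed
    token_states=states,
    prev_recipient=None,
    include_reasoning=True,
)
# delta.reasoning == "thinking... ", delta.content == "Hello world",
# and tools_streamed is False since no "functions." recipient appeared.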
372
vllm/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
validate_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.constants import (
|
||||
H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: str | Sequence[str] | None,
|
||||
option_string: str | None = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: list[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, ""]: # Skip if item is None or empty string
|
||||
continue
|
||||
if "=" in item and "," not in item: # Old format: name=path
|
||||
name, path = item.split("=")
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
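As a rough illustration of the two accepted input forms (the adapter names and paths below are hypothetical), the action above turns them into LoRAModulePath entries:

# Illustrative sketch only: a standalone parser wired with LoRAParserAction.
import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
ns = demo_parser.parse_args([
    "--lora-modules",
    "sql-lora=/adapters/sql",                           # old 'name=path' form
    '{"name": "chat-lora", "path": "/adapters/chat"}',  # JSON form
])
# ns.lora_modules is roughly:
# [LoRAModulePath(name="sql-lora", path="/adapters/sql"),
#  LoRAModulePath(name="chat-lora", path="/adapters/chat")]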
|
||||
|
||||
@config
|
||||
class BaseFrontendArgs:
|
||||
"""Base arguments for the OpenAI-compatible frontend server.
|
||||
|
||||
This base class does not include host, port, and server-specific arguments
|
||||
like SSL, CORS, and HTTP server settings. Those arguments are added by
|
||||
the subclasses.
|
||||
"""
|
||||
|
||||
lora_modules: list[LoRAModulePath] | None = None
|
||||
"""LoRA modules configurations in either 'name=path' format or JSON format
|
||||
or JSON list format. Example (old format): `'name=path'` Example (new
|
||||
format): `{\"name\": \"name\", \"path\": \"lora_path\",
|
||||
\"base_model_name\": \"id\"}`"""
|
||||
chat_template: str | None = None
|
||||
"""The file path to the chat template, or the template in single-line form
|
||||
for the specified model."""
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
default_chat_template_kwargs: dict[str, Any] | None = None
|
||||
"""Default keyword arguments to pass to the chat template renderer.
|
||||
These will be merged with request-level chat_template_kwargs,
|
||||
with request values taking precedence. Useful for setting default
|
||||
behavior for reasoning models. Example: '{"enable_thinking": false}'
|
||||
to disable thinking mode by default for Qwen3/DeepSeek models."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
return_tokens_as_token_ids: bool = False
|
||||
"""When `--max-logprobs` is specified, represents single tokens as
|
||||
strings of the form 'token_id:{token_id}' so that tokens that are not
|
||||
JSON-encodable can be identified."""
|
||||
disable_frontend_multiprocessing: bool = False
|
||||
"""If specified, will run the OpenAI frontend server in the same process as
|
||||
the model serving engine."""
|
||||
enable_auto_tool_choice: bool = False
|
||||
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
|
||||
to specify which parser to use."""
|
||||
exclude_tools_when_tool_choice_none: bool = False
|
||||
"""If specified, exclude tool definitions in prompts when
|
||||
tool_choice='none'."""
|
||||
tool_call_parser: str | None = None
|
||||
"""Select the tool call parser depending on the model that you're using.
|
||||
This is used to parse the model-generated tool call into OpenAI API format.
|
||||
Required for `--enable-auto-tool-choice`. You can choose any option from
|
||||
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
|
||||
tool_parser_plugin: str = ""
|
||||
"""Special the tool parser plugin write to parse the model-generated tool
|
||||
into OpenAI API format, the name register in this plugin can be used in
|
||||
`--tool-call-parser`."""
|
||||
tool_server: str | None = None
|
||||
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
|
||||
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
|
||||
purpose."""
|
||||
log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH
|
||||
"""Path to logging config JSON file for both vllm and uvicorn"""
|
||||
max_log_len: int | None = None
|
||||
"""Max number of prompt characters or prompt ID numbers being printed in
|
||||
log. The default of None means unlimited."""
|
||||
enable_prompt_tokens_details: bool = False
|
||||
"""If set to True, enable prompt_tokens_details in usage."""
|
||||
enable_server_load_tracking: bool = False
|
||||
"""If set to True, enable tracking server_load_metrics in the app state."""
|
||||
enable_force_include_usage: bool = False
|
||||
"""If set to True, including usage on every request."""
|
||||
enable_tokenizer_info_endpoint: bool = False
|
||||
"""Enable the `/tokenizer_info` endpoint. May expose chat
|
||||
templates and other tokenizer configuration."""
|
||||
enable_log_outputs: bool = False
|
||||
"""If set to True, log model outputs (generations).
|
||||
Requires --enable-log-requests."""
|
||||
enable_log_deltas: bool = True
|
||||
"""If set to False, output deltas will not be logged. Relevant only if
|
||||
--enable-log-outputs is set.
|
||||
"""
|
||||
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
||||
"""If set to True, log the stack trace of error responses"""
|
||||
tokens_only: bool = False
|
||||
"""
|
||||
If set to True, only enable the Tokens In<>Out endpoint.
|
||||
This is intended for use in a Disaggregated Everything setup.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
cls,
|
||||
frontend_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Customize argparse kwargs before arguments are registered.
|
||||
|
||||
Subclasses should override this and call
|
||||
``super()._customize_cli_kwargs(frontend_kwargs)`` first.
|
||||
"""
|
||||
# Special case: default_chat_template_kwargs needs json.loads type
|
||||
frontend_kwargs["default_chat_template_kwargs"]["type"] = json.loads
|
||||
|
||||
# Special case: LoRA modules need custom parser action and
|
||||
# optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
|
||||
|
||||
# Special case: Tool call parser shows built-in options.
|
||||
valid_tool_parsers = list(ToolParserManager.list_registered())
|
||||
parsers_str = ",".join(valid_tool_parsers)
|
||||
frontend_kwargs["tool_call_parser"]["metavar"] = (
|
||||
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
|
||||
)
|
||||
return frontend_kwargs
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Register CLI arguments for this frontend class.
|
||||
|
||||
Subclasses should override ``_customize_cli_kwargs`` instead of
|
||||
this method so that base-class postprocessing is always applied.
|
||||
"""
|
||||
from vllm.engine.arg_utils import get_kwargs
|
||||
|
||||
frontend_kwargs = get_kwargs(cls)
|
||||
frontend_kwargs = cls._customize_cli_kwargs(frontend_kwargs)
|
||||
|
||||
group_name = cls.__name__.replace("Args", "")
|
||||
frontend_group = parser.add_argument_group(
|
||||
title=group_name,
|
||||
description=cls.__doc__,
|
||||
)
|
||||
for key, value in frontend_kwargs.items():
|
||||
extra_flags = value.pop("flags", [])
|
||||
frontend_group.add_argument(
|
||||
*extra_flags, f"--{key.replace('_', '-')}", **value
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@config
|
||||
class FrontendArgs(BaseFrontendArgs):
|
||||
"""Arguments for the OpenAI-compatible frontend server."""
|
||||
|
||||
host: str | None = None
|
||||
"""Host name."""
|
||||
port: int = 8000
|
||||
"""Port number."""
|
||||
uds: str | None = None
|
||||
"""Unix domain socket path. If set, host and port arguments are ignored."""
|
||||
uvicorn_log_level: Literal[
|
||||
"critical", "error", "warning", "info", "debug", "trace"
|
||||
] = "info"
|
||||
"""Log level for uvicorn."""
|
||||
disable_uvicorn_access_log: bool = False
|
||||
"""Disable uvicorn access log."""
|
||||
disable_access_log_for_endpoints: str | None = None
|
||||
"""Comma-separated list of endpoint paths to exclude from uvicorn access
|
||||
logs. This is useful to reduce log noise from high-frequency endpoints
|
||||
like health checks. Example: "/health,/metrics,/ping".
|
||||
When set, access logs for requests to these paths will be suppressed
|
||||
while keeping logs for other endpoints."""
|
||||
allow_credentials: bool = False
|
||||
"""Allow credentials."""
|
||||
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed origins."""
|
||||
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed methods."""
|
||||
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed headers."""
|
||||
api_key: list[str] | None = None
|
||||
"""If provided, the server will require one of these keys to be presented in
|
||||
the header."""
|
||||
ssl_keyfile: str | None = None
|
||||
"""The file path to the SSL key file."""
|
||||
ssl_certfile: str | None = None
|
||||
"""The file path to the SSL cert file."""
|
||||
ssl_ca_certs: str | None = None
|
||||
"""The CA certificates file."""
|
||||
enable_ssl_refresh: bool = False
|
||||
"""Refresh SSL Context when SSL certificate files change"""
|
||||
ssl_cert_reqs: int = int(ssl.CERT_NONE)
|
||||
"""Whether client certificate is required (see stdlib ssl module's)."""
|
||||
ssl_ciphers: str | None = None
|
||||
"""SSL cipher suites for HTTPS (TLS 1.2 and below only).
|
||||
Example: 'ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305'"""
|
||||
root_path: str | None = None
|
||||
"""FastAPI root_path when app is behind a path based routing proxy."""
|
||||
middleware: list[str] = field(default_factory=lambda: [])
|
||||
"""Additional ASGI middleware to apply to the app. We accept multiple
|
||||
--middleware arguments. The value should be an import path. If a function
|
||||
is provided, vLLM will add it to the server using
|
||||
`@app.middleware('http')`. If a class is provided, vLLM will
|
||||
add it to the server using `app.add_middleware()`."""
|
||||
enable_request_id_headers: bool = False
|
||||
"""If specified, API server will add X-Request-Id header to responses."""
|
||||
disable_fastapi_docs: bool = False
|
||||
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
|
||||
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
|
||||
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
|
||||
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
"""Maximum number of HTTP headers allowed in a request for h11 parser.
|
||||
Helps mitigate header abuse. Default: 256."""
|
||||
enable_offline_docs: bool = False
|
||||
"""
|
||||
Enable offline FastAPI documentation for air-gapped environments.
|
||||
Uses vendored static assets bundled with vLLM.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
cls,
|
||||
frontend_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)
|
||||
|
||||
# Special case: allowed_origins, allowed_methods, allowed_headers all
|
||||
# need json.loads type
|
||||
# Should also remove nargs
|
||||
frontend_kwargs["allowed_origins"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_methods"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_headers"]["type"] = json.loads
|
||||
del frontend_kwargs["allowed_origins"]["nargs"]
|
||||
del frontend_kwargs["allowed_methods"]["nargs"]
|
||||
del frontend_kwargs["allowed_headers"]["nargs"]
|
||||
|
||||
# Special case: Middleware needs to append action
|
||||
frontend_kwargs["middleware"]["action"] = "append"
|
||||
frontend_kwargs["middleware"]["type"] = str
|
||||
if "nargs" in frontend_kwargs["middleware"]:
|
||||
del frontend_kwargs["middleware"]["nargs"]
|
||||
frontend_kwargs["middleware"]["default"] = []
|
||||
|
||||
# Special case: disable_access_log_for_endpoints is a single
|
||||
# comma-separated string, not a list
|
||||
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
|
||||
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
|
||||
|
||||
return frontend_kwargs
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Create the CLI argument parser used by the OpenAI API server.
|
||||
|
||||
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
|
||||
register all arguments instead of manually enumerating them here. This
|
||||
avoids code duplication and keeps the argument definitions in one place.
|
||||
"""
|
||||
parser.add_argument(
|
||||
"model_tag",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="The model tag to serve (optional if specified in config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-server-count",
|
||||
"-asc",
|
||||
type=int,
|
||||
default=None,
|
||||
help="How many API server processes to run. "
|
||||
"Defaults to data_parallel_size if not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="Read CLI options from a config file. "
|
||||
"Must be a YAML with the following options: "
|
||||
"https://docs.vllm.ai/en/latest/configuration/serve_args.html",
|
||||
)
|
||||
parser = FrontendArgs.add_cli_args(parser)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
|
||||
if args.enable_log_outputs and not args.enable_log_requests:
|
||||
raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server"
|
||||
)
|
||||
return make_arg_parser(parser_for_docs)
|
||||
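A rough end-to-end sketch of these helpers (the model tag below is hypothetical; real invocations come from the `vllm serve` CLI):

# Illustrative sketch only: build the full parser, parse a sample command
# line, and run the fast pre-load validation.
sketch_parser = make_arg_parser(FlexibleArgumentParser(prog="vllm serve"))
sketch_args = sketch_parser.parse_args(["my-org/my-model", "--port", "8000"])
validate_parsed_serve_args(sketch_args)  # raises if e.g. auto tool choice lacks a parser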
2
vllm/entrypoints/openai/completion/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
106
vllm/entrypoints/openai/completion/api_router.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.orca_metrics import metrics_header
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion | None:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
metrics_header_format = raw_request.headers.get(
|
||||
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
|
||||
)
|
||||
handler = completion(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Completions API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(),
|
||||
headers=metrics_header(metrics_header_format),
|
||||
)
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions/render",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
response_model=list,
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def render_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""render completion request and return engine prompts without generating."""
|
||||
handler = completion(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Completions API"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler.render_completion_request(request)
|
||||
except Exception as e:
|
||||
result = handler.create_error_response(e)
|
||||
|
||||
if isinstance(result, ErrorResponse):
|
||||
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
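A hedged wiring sketch (the handler objects are assumed to be built during engine startup and are hypothetical parameters here; the router itself only reads them from `app.state`):

# Illustrative sketch only: how an application might mount this router.
from fastapi import FastAPI


def build_completion_app(completion_handler, tokenization_handler) -> FastAPI:
    app = FastAPI()
    # The endpoints look these up via request.app.state at request time.
    app.state.openai_serving_completion = completion_handler
    app.state.openai_serving_tokenization = tokenization_handler
    attach_router(app)  # registers /v1/completions and /v1/completions/render
    return app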
475
vllm/entrypoints/openai/completion/protocol.py
Normal file
@@ -0,0 +1,475 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import json
|
||||
import time
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.utils import replace
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
AnyResponseFormat,
|
||||
LegacyStructuralTagResponseFormat,
|
||||
OpenAIBaseModel,
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str | None = None
|
||||
prompt: (
|
||||
list[Annotated[int, Field(ge=0)]]
|
||||
| list[list[Annotated[int, Field(ge=0)]]]
|
||||
| str
|
||||
| list[str]
|
||||
| None
|
||||
) = None
|
||||
echo: bool | None = False
|
||||
frequency_penalty: float | None = 0.0
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: int | None = None
|
||||
max_tokens: int | None = 16
|
||||
n: int = 1
|
||||
presence_penalty: float | None = 0.0
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: str | list[str] | None = []
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
suffix: str | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:completion-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
top_k: int | None = None
|
||||
min_p: float | None = None
|
||||
repetition_penalty: float | None = None
|
||||
length_penalty: float = 1.0
|
||||
stop_token_ids: list[int] | None = []
|
||||
include_stop_str_in_output: bool = False
|
||||
ignore_eos: bool = False
|
||||
min_tokens: int = 0
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
|
||||
None
|
||||
)
|
||||
allowed_token_ids: list[int] | None = None
|
||||
prompt_logprobs: int | None = None
|
||||
# --8<-- [end:completion-sampling-params]
|
||||
|
||||
# --8<-- [start:completion-extra-params]
|
||||
prompt_embeds: bytes | list[bytes] | None = None
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
response_format: AnyResponseFormat | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Similar to chat completion, this parameter specifies the format "
|
||||
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
|
||||
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
|
||||
),
|
||||
)
|
||||
structured_outputs: StructuredOutputsParams | None = Field(
|
||||
default=None,
|
||||
description="Additional kwargs for structured outputs",
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
|
||||
return_tokens_as_token_ids: bool | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified with 'logprobs', tokens are represented "
|
||||
" as strings of the form 'token_id:{token_id}' so that tokens "
|
||||
"that are not JSON-encodable can be identified."
|
||||
),
|
||||
)
|
||||
return_token_ids: bool | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the result will include token IDs alongside the "
|
||||
"generated text. In streaming mode, prompt_token_ids is included "
|
||||
"only in the first chunk, and token_ids contains the delta tokens "
|
||||
"for each chunk. This is useful for debugging or when you "
|
||||
"need to map generated text back to input tokens."
|
||||
),
|
||||
)
|
||||
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
vllm_xargs: dict[str, str | int | float] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
|
||||
# --8<-- [end:completion-extra-params]
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=self.max_tokens or 0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
needs_detokenization=bool(self.echo and not self.return_token_ids),
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_tokens",
|
||||
)
|
||||
|
||||
# Default sampling parameters for completion requests
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 0,
|
||||
"min_p": 0.0,
|
||||
}
|
||||
|
||||
def to_beam_search_params(
|
||||
self,
|
||||
max_tokens: int,
|
||||
default_sampling_params: dict | None = None,
|
||||
) -> BeamSearchParams:
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
n = self.n if self.n is not None else 1
|
||||
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get("temperature", 1.0)
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
max_tokens: int,
|
||||
default_sampling_params: dict | None = None,
|
||||
) -> SamplingParams:
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
# Default parameters
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
"repetition_penalty",
|
||||
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
|
||||
)
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
||||
)
|
||||
if (top_k := self.top_k) is None:
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||
)
|
||||
if (min_p := self.min_p) is None:
|
||||
min_p = default_sampling_params.get(
|
||||
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
|
||||
)
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.logprobs
|
||||
|
||||
echo_without_generation = self.echo and self.max_tokens == 0
|
||||
|
||||
response_format = self.response_format
|
||||
if response_format is not None:
|
||||
structured_outputs_kwargs = dict[str, Any]()
|
||||
|
||||
# Set structured output params for response format
|
||||
if response_format.type == "json_object":
|
||||
structured_outputs_kwargs["json_object"] = True
|
||||
elif response_format.type == "json_schema":
|
||||
json_schema = response_format.json_schema
|
||||
assert json_schema is not None
|
||||
structured_outputs_kwargs["json"] = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuralTagResponseFormat,
|
||||
),
|
||||
)
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
|
||||
|
||||
# If structured outputs wasn't already enabled,
|
||||
# we must enable it for these features to work
|
||||
if len(structured_outputs_kwargs) > 0:
|
||||
self.structured_outputs = (
|
||||
StructuredOutputsParams(**structured_outputs_kwargs)
|
||||
if self.structured_outputs is None
|
||||
else replace(self.structured_outputs, **structured_outputs_kwargs)
|
||||
)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
# Pass in kv_transfer_params via extra_args
|
||||
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
seed=self.seed,
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
logprobs=self.logprobs,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=max_tokens if not echo_without_generation else 1,
|
||||
min_tokens=self.min_tokens,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
structured_outputs=self.structured_outputs,
|
||||
logit_bias=self.logit_bias,
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=extra_args or None,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_structured_outputs_count(cls, data):
|
||||
if data.get("structured_outputs", None) is None:
|
||||
return data
|
||||
|
||||
structured_outputs_kwargs = data["structured_outputs"]
|
||||
# structured_outputs may arrive as a dict (from JSON/raw kwargs) or
|
||||
# as a StructuredOutputsParams dataclass instance.
|
||||
is_dataclass = isinstance(structured_outputs_kwargs, StructuredOutputsParams)
|
||||
count = sum(
|
||||
(
|
||||
getattr(structured_outputs_kwargs, k, None)
|
||||
if is_dataclass
|
||||
else structured_outputs_kwargs.get(k)
|
||||
)
|
||||
is not None
|
||||
for k in ("json", "regex", "choice")
|
||||
)
|
||||
if count > 1:
|
||||
raise VLLMValidationError(
|
||||
"You can only use one kind of constraints for structured "
|
||||
"outputs ('json', 'regex' or 'choice').",
|
||||
parameter="structured_outputs",
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.",
|
||||
parameter="prompt_logprobs",
|
||||
)
|
||||
|
||||
if prompt_logprobs < 0 and prompt_logprobs != -1:
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` must be a positive value or -1.",
|
||||
parameter="prompt_logprobs",
|
||||
value=prompt_logprobs,
|
||||
)
|
||||
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
|
||||
raise VLLMValidationError(
|
||||
"`logprobs` must be a positive value.",
|
||||
parameter="logprobs",
|
||||
value=logprobs,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter="stream_options",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_prompt_and_prompt_embeds(cls, data):
|
||||
prompt = data.get("prompt")
|
||||
prompt_embeds = data.get("prompt_embeds")
|
||||
|
||||
prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
|
||||
embeds_is_empty = prompt_embeds is None or (
|
||||
isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
|
||||
)
|
||||
|
||||
if prompt_is_empty and embeds_is_empty:
|
||||
raise ValueError(
|
||||
"Either prompt or prompt_embeds must be provided and non-empty."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_cache_salt_support(cls, data):
|
||||
if data.get("cache_salt") is not None and (
|
||||
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided."
|
||||
)
|
||||
return data
|
||||
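To make the request-to-engine-parameters flow concrete, here is a hedged sketch (the values are hypothetical; in the server the `max_tokens` budget is resolved against the model's context length before this call):

# Illustrative sketch only: construct a request and derive sampling params.
req = CompletionRequest(
    model="my-model",   # hypothetical model name
    prompt="Say hello",
    max_tokens=32,
    temperature=0.2,
)
params = req.to_sampling_params(
    max_tokens=32,               # already clamped by the serving layer
    default_sampling_params={},  # per-model generation-config defaults go here
)
# params.temperature == 0.2; unset knobs fall back to _DEFAULT_SAMPLING_PARAMS.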
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: list[int] = Field(default_factory=list)
|
||||
token_logprobs: list[float | None] = Field(default_factory=list)
|
||||
tokens: list[str] = Field(default_factory=list)
|
||||
top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: CompletionLogProbs | None = None
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The stop string or token id that caused the completion "
|
||||
"to stop, None if the completion finished for some other reason "
|
||||
"including encountering the EOS token"
|
||||
),
|
||||
)
|
||||
token_ids: list[int] | None = None # For response
|
||||
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
|
||||
prompt_token_ids: list[int] | None = None # For prompt
|
||||
|
||||
|
||||
class CompletionResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[CompletionResponseChoice]
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
|
||||
system_fingerprint: str | None = None
|
||||
usage: UsageInfo
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None, description="KVTransfer parameters."
|
||||
)
|
||||
|
||||
|
||||
class CompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: CompletionLogProbs | None = None
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The stop string or token id that caused the completion "
|
||||
"to stop, None if the completion finished for some other reason "
|
||||
"including encountering the EOS token"
|
||||
),
|
||||
)
|
||||
# not part of the OpenAI spec but for tracing the tokens
|
||||
# prompt tokens is put into choice to align with CompletionResponseChoice
|
||||
prompt_token_ids: list[int] | None = None
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class CompletionStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[CompletionResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
681
vllm/entrypoints/openai/completion/serving.py
Normal file
@@ -0,0 +1,681 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import (
|
||||
GenerationError,
|
||||
OpenAIServing,
|
||||
clamp_prompt_logprobs,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
mc = self.model_config
|
||||
self.override_max_tokens = (
|
||||
self.default_sampling_params.get("max_tokens")
|
||||
if mc.generation_config not in ("auto", "vllm")
|
||||
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
|
||||
)
|
||||
|
||||
async def render_completion_request(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
) -> list[ProcessorInputs] | ErrorResponse:
|
||||
"""
|
||||
render completion request by validating and preprocessing inputs.
|
||||
|
||||
Returns:
|
||||
A list of engine_prompts on success,
|
||||
or an ErrorResponse on failure.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response("suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response("Echo is unsupported with prompt embeds.")
|
||||
|
||||
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds."
|
||||
)
|
||||
|
||||
try:
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
)
|
||||
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(e)
|
||||
|
||||
return engine_prompts
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
result = await self.render_completion_request(request)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
|
||||
engine_prompts = result
|
||||
|
||||
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
logger.exception("Error preparing request components")
|
||||
return self.create_error_response(e)
|
||||
|
||||
# Extract data_parallel_rank from header (router can inject it)
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
request.max_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
self.override_max_tokens,
|
||||
)
|
||||
|
||||
sampling_params: SamplingParams | BeamSearchParams
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params
|
||||
)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# We do not stream the results when using beam search.
|
||||
stream = request.stream and not request.use_beam_search
|
||||
|
||||
# Streaming response
|
||||
tokenizer = self.renderer.tokenizer
|
||||
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into the vLLM engine to avoid being redundant
# with the input token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except GenerationError as e:
|
||||
return self._convert_generation_error_to_response(e)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[ProcessorInputs],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
num_cached_tokens = None
|
||||
first_iteration = True
|
||||
|
||||
stream_options = request.stream_options
|
||||
include_usage, include_continuous_usage = should_include_usage(
|
||||
stream_options, self.enable_force_include_usage
|
||||
)
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
|
||||
if first_iteration:
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
first_iteration = False
|
||||
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
|
||||
# Useful when request.return_token_ids is True
|
||||
# Returning prompt token IDs shares the same logic
|
||||
# with the echo implementation.
|
||||
prompt_token_ids_to_return: list[int] | None = None
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids,
|
||||
*output.token_ids,
|
||||
]
|
||||
out_logprobs = [
|
||||
*(prompt_logprobs or []),
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
# has_echoed[i] is reused here to indicate whether
|
||||
# we have already returned the prompt token IDs.
|
||||
if not has_echoed[i] and request.return_token_ids:
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
|
||||
if (
|
||||
not delta_text
|
||||
and not delta_token_ids
|
||||
and not previous_num_tokens[i]
|
||||
):
|
||||
# Chunked prefill case, don't return empty chunks
|
||||
continue
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
self._raise_if_error(finish_reason, request_id)
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
prompt_token_ids=prompt_token_ids_to_return,
|
||||
token_ids=(
|
||||
as_list(output.token_ids)
|
||||
if request.return_token_ids
|
||||
else None
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
if include_continuous_usage:
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=num_cached_tokens
|
||||
)
|
||||
|
||||
if include_usage:
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=final_usage_info,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True
|
||||
)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = final_usage_info
|
||||
|
||||
except GenerationError as e:
|
||||
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in completion stream generator.")
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: list[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
kv_transfer_params = None
|
||||
last_final_res = None
|
||||
for final_res in final_res_batch:
|
||||
last_final_res = final_res
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
||||
|
||||
for output in final_res.outputs:
|
||||
self._raise_if_error(output.finish_reason, request_id)
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
else:
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
prompt_token_ids=(
|
||||
prompt_token_ids if request.return_token_ids else None
|
||||
),
|
||||
token_ids=(
|
||||
as_list(output.token_ids) if request.return_token_ids else None
|
||||
),
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
self.enable_prompt_tokens_details
|
||||
and last_final_res
|
||||
and last_final_res.num_cached_tokens
|
||||
):
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=last_final_res.num_cached_tokens
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
if final_res_batch:
|
||||
kv_transfer_params = final_res_batch[0].kv_transfer_params
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[dict[int, Logprob] | None],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: TokenizerLike | None,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: bool | None = None,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: list[int] = []
|
||||
out_token_logprobs: list[float | None] = []
|
||||
out_tokens: list[str] = []
|
||||
out_top_logprobs: list[dict[str, float] | None] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
should_return_as_token_id = (
|
||||
return_as_token_id
|
||||
if return_as_token_id is not None
|
||||
else self.return_tokens_as_token_ids
|
||||
)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
else:
|
||||
if tokenizer is None:
|
||||
raise VLLMValidationError(
|
||||
"Unable to get tokenizer because "
|
||||
"`skip_tokenizer_init=True`",
|
||||
parameter="skip_tokenizer_init",
|
||||
value=True,
|
||||
)
|
||||
|
||||
token = tokenizer.decode(token_id)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append(
|
||||
{
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
): max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
}
|
||||
)
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
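# Illustrative sketch of the structure built above (hypothetical values,
# assuming a tokenizer that decodes ids 11 and 12 to "Hello" and " world"):
# text_offset accumulates decoded lengths starting at initial_text_offset,
# and positions without logprob information stay None.
#
# CompletionLogProbs(
#     text_offset=[0, 5],
#     token_logprobs=[-0.11, -0.35],
#     tokens=["Hello", " world"],
#     top_logprobs=[{"Hello": -0.11}, {" world": -0.35}],
# )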
2
vllm/entrypoints/openai/engine/__init__.py
Normal file
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
312
vllm/entrypoints/openai/engine/protocol.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from typing import Any, ClassVar, Literal, TypeAlias
|
||||
|
||||
import regex as re
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIBaseModel(BaseModel):
|
||||
# OpenAI API does allow extra fields
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
# Cache class field names
|
||||
field_names: ClassVar[set[str] | None] = None
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def __log_extra_fields__(cls, data, handler):
|
||||
result = handler(data)
|
||||
if not isinstance(data, dict):
|
||||
return result
|
||||
field_names = cls.field_names
|
||||
if field_names is None:
|
||||
# Get all class field names and their potential aliases
|
||||
field_names = set()
|
||||
for field_name, field in cls.model_fields.items():
|
||||
field_names.add(field_name)
|
||||
if alias := getattr(field, "alias", None):
|
||||
field_names.add(alias)
|
||||
cls.field_names = field_names
|
||||
|
||||
# Compare against both field names and aliases
|
||||
if any(k not in field_names for k in data):
|
||||
logger.warning(
|
||||
"The following fields were present in the request but ignored: %s",
|
||||
data.keys() - field_names,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ErrorInfo(OpenAIBaseModel):
|
||||
message: str
|
||||
type: str
|
||||
param: str | None = None
|
||||
code: int
|
||||
|
||||
|
||||
class ErrorResponse(OpenAIBaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class ModelPermission(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
|
||||
object: str = "model_permission"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
allow_create_engine: bool = False
|
||||
allow_sampling: bool = True
|
||||
allow_logprobs: bool = True
|
||||
allow_search_indices: bool = False
|
||||
allow_view: bool = True
|
||||
allow_fine_tuning: bool = False
|
||||
organization: str = "*"
|
||||
group: str | None = None
|
||||
is_blocking: bool = False
|
||||
|
||||
|
||||
class ModelCard(OpenAIBaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "vllm"
|
||||
root: str | None = None
|
||||
parent: str | None = None
|
||||
max_model_len: int | None = None
|
||||
permission: list[ModelPermission] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ModelList(OpenAIBaseModel):
|
||||
object: str = "list"
|
||||
data: list[ModelCard] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTokenUsageInfo(OpenAIBaseModel):
|
||||
cached_tokens: int | None = None
|
||||
|
||||
|
||||
class UsageInfo(OpenAIBaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: int | None = 0
|
||||
prompt_tokens_details: PromptTokenUsageInfo | None = None
|
||||
|
||||
|
||||
class RequestResponseMetadata(BaseModel):
|
||||
request_id: str
|
||||
final_usage_info: UsageInfo | None = None
|
||||
|
||||
|
||||
class JsonSchemaResponseFormat(OpenAIBaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
# schema is the field in openai but that causes conflicts with pydantic so
|
||||
# instead use json_schema with an alias
|
||||
json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
class LegacyStructuralTag(OpenAIBaseModel):
|
||||
begin: str
|
||||
# schema is the field, but that causes conflicts with pydantic so
|
||||
# instead use structural_tag_schema with an alias
|
||||
structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
|
||||
end: str
|
||||
|
||||
|
||||
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: list[LegacyStructuralTag]
|
||||
triggers: list[str]
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(OpenAIBaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
format: Any
|
||||
|
||||
|
||||
AnyStructuralTagResponseFormat: TypeAlias = (
|
||||
LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
|
||||
)
|
||||
|
||||
|
||||
class ResponseFormat(OpenAIBaseModel):
|
||||
# type must be "json_schema", "json_object", or "text"
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: JsonSchemaResponseFormat | None = None
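# Hedged usage sketch (hypothetical schema): a request-body fragment accepted
# by the models above. Note the inner key is "schema", which pydantic maps to
# the json_schema field through its alias.
#
# "response_format": {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "weather_report",
#         "schema": {"type": "object", "properties": {"temp": {"type": "number"}}},
#         "strict": True,
#     },
# }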
|
||||
|
||||
|
||||
AnyResponseFormat: TypeAlias = (
|
||||
ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
|
||||
)
|
||||
|
||||
|
||||
class StreamOptions(OpenAIBaseModel):
|
||||
include_usage: bool | None = True
|
||||
continuous_usage_stats: bool | None = False
|
||||
|
||||
|
||||
class FunctionDefinition(OpenAIBaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# extra="forbid" is a workaround to have kwargs as a field,
|
||||
# see https://github.com/pydantic/pydantic/issues/3125
|
||||
class LogitsProcessorConstructor(BaseModel):
|
||||
qualname: str
|
||||
args: list[Any] | None = None
|
||||
kwargs: dict[str, Any] | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
LogitsProcessors = list[str | LogitsProcessorConstructor]
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
processors: LogitsProcessors | None, pattern: str | None
|
||||
) -> list[Any] | None:
|
||||
if processors and pattern:
|
||||
logits_processors = []
|
||||
for processor in processors:
|
||||
qualname = processor if isinstance(processor, str) else processor.qualname
|
||||
if not re.match(pattern, qualname):
|
||||
raise ValueError(
|
||||
f"Logits processor '{qualname}' is not allowed by this "
|
||||
"server. See --logits-processor-pattern engine argument "
|
||||
"for more information."
|
||||
)
|
||||
try:
|
||||
logits_processor = resolve_obj_by_qualname(qualname)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Logits processor '{qualname}' could not be resolved: {e}"
|
||||
) from e
|
||||
if isinstance(processor, LogitsProcessorConstructor):
|
||||
logits_processor = logits_processor(
|
||||
*processor.args or [], **processor.kwargs or {}
|
||||
)
|
||||
logits_processors.append(logits_processor)
|
||||
return logits_processors
|
||||
elif processors:
|
||||
raise ValueError(
|
||||
"The `logits_processors` argument is not supported by this "
|
||||
"server. See --logits-processor-pattern engine argument "
|
||||
"for more information."
|
||||
)
|
||||
return None
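# Minimal usage sketch for the helper above. The qualname
# "my_pkg.processors.StopAfterNewline" is a hypothetical target; any object
# importable via resolve_obj_by_qualname and matching the server's
# --logits-processor-pattern is handled the same way.
def _example_logits_processors() -> list[Any] | None:
    requested: LogitsProcessors = [
        "my_pkg.processors.StopAfterNewline",
        LogitsProcessorConstructor(
            qualname="my_pkg.processors.StopAfterNewline",
            kwargs={"max_newlines": 1},
        ),
    ]
    return get_logits_processors(requested, pattern=r"my_pkg\.processors\..*")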
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
|
||||
# Internal field to preserve native tool call ID from tool parser.
|
||||
# Excluded from serialization to maintain OpenAI API compatibility
|
||||
# (function object should only contain 'name' and 'arguments').
|
||||
id: str | None = Field(default=None, exclude=True)
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=make_tool_call_id)
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class DeltaFunctionCall(BaseModel):
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
# a tool call delta where everything is optional
|
||||
class DeltaToolCall(OpenAIBaseModel):
|
||||
id: str | None = None
|
||||
type: Literal["function"] | None = None
|
||||
index: int
|
||||
function: DeltaFunctionCall | None = None
|
||||
|
||||
|
||||
class ExtractedToolCallInformation(BaseModel):
|
||||
# indicate if tools were called
|
||||
tools_called: bool
|
||||
|
||||
# extracted tool calls
|
||||
tool_calls: list[ToolCall]
|
||||
|
||||
# content - per OpenAI spec, content AND tool calls can be returned rarely
|
||||
# But some models will do this intentionally
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class DeltaMessage(OpenAIBaseModel):
|
||||
role: str | None = None
|
||||
content: str | None = None
|
||||
reasoning: str | None = None
|
||||
tool_calls: list[DeltaToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
token_ids: list[int]
|
||||
"""The token ids to generate text from."""
|
||||
|
||||
# features: MultiModalFeatureSpec
|
||||
# TODO (NickLucche): implement once Renderer work is completed
|
||||
features: str | None = None
|
||||
"""The processed MM inputs for the model."""
|
||||
|
||||
sampling_params: SamplingParams
|
||||
"""The sampling parameters for the model."""
|
||||
|
||||
model: str | None = None
|
||||
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
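# Hedged sketch of a token-in/token-out request body for the class above
# (hypothetical token ids and sampling values); it assumes the server
# validates the "sampling_params" object into SamplingParams.
def _example_generate_request_payload() -> dict[str, Any]:
    return {
        "token_ids": [1, 2, 3, 4],
        "sampling_params": {"max_tokens": 16, "temperature": 0.0},
        "stream": False,
        "priority": 0,
    }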
1314
vllm/entrypoints/openai/engine/serving.py
Normal file
File diff suppressed because it is too large
0
vllm/entrypoints/openai/generate/__init__.py
Normal file
166
vllm/entrypoints/openai/generate/api_router.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
|
||||
def register_generate_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.openai.chat_completion.api_router import (
|
||||
attach_router as register_chat_api_router,
|
||||
)
|
||||
|
||||
register_chat_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.responses.api_router import (
|
||||
attach_router as register_responses_api_router,
|
||||
)
|
||||
|
||||
register_responses_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.completion.api_router import (
|
||||
attach_router as register_completion_api_router,
|
||||
)
|
||||
|
||||
register_completion_api_router(app)
|
||||
|
||||
from vllm.entrypoints.anthropic.api_router import (
|
||||
attach_router as register_anthropic_api_router,
|
||||
)
|
||||
|
||||
register_anthropic_api_router(app)
|
||||
|
||||
|
||||
async def init_generate_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.mcp.tool_server import (
|
||||
DemoToolServer,
|
||||
MCPToolServer,
|
||||
ToolServer,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: ToolServer | None = DemoToolServer()
|
||||
assert isinstance(tool_server, DemoToolServer)
|
||||
await tool_server.init_and_validate()
|
||||
elif args.tool_server:
|
||||
tool_server = MCPToolServer()
|
||||
await tool_server.add_tool_server(args.tool_server)
|
||||
else:
|
||||
tool_server = None
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
state.openai_serving_responses = (
|
||||
OpenAIServingResponses(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
tool_server=tool_server,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
default_chat_template_kwargs=args.default_chat_template_kwargs,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
enable_log_deltas=args.enable_log_deltas,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# Warm up chat template processing to avoid first-request latency
|
||||
if state.openai_serving_chat is not None:
|
||||
await state.openai_serving_chat.warmup()
|
||||
state.openai_serving_completion = (
|
||||
OpenAIServingCompletion(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.anthropic_serving_messages = (
|
||||
AnthropicServingMessages(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.serving_tokens = (
|
||||
ServingTokens(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
force_no_detokenize=args.tokens_only,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
0
vllm/entrypoints/openai/models/__init__.py
Normal file
29
vllm/entrypoints/openai/models/api_router.py
Normal file
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse

from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def models(request: Request) -> OpenAIServingModels:
    return request.app.state.openai_serving_models


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler = models(raw_request)

    models_ = await handler.show_available_models()
    return JSONResponse(content=models_.model_dump())


def attach_router(app: FastAPI):
    app.include_router(router)
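# Usage sketch (assuming a server on localhost:8000): a GET request to
# /v1/models returns a ModelList payload shaped roughly like
# {"object": "list", "data": [{"id": "<served-model-name>", "object": "model",
#  "owned_by": "vllm", "max_model_len": ..., "permission": [...]}]}.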
18
vllm/entrypoints/openai/models/protocol.py
Normal file
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass


@dataclass
class BaseModelPath:
    name: str
    model_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str
    base_model_name: str | None = None
308
vllm/entrypoints/openai/models/serving.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath, LoRAModulePath
|
||||
from vllm.entrypoints.serve.lora.protocol import (
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.entrypoints.utils import sanitize_message
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.utils.counter import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: list[LoRAModulePath] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers():
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name)
|
||||
)
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.model_config = self.engine_client.model_config
|
||||
self.renderer = self.engine_client.renderer
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.input_processor = self.engine_client.input_processor
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
if self.static_lora_modules is None:
|
||||
return
|
||||
for lora in self.static_lora_modules:
|
||||
load_request = LoadLoRAAdapterRequest(
|
||||
lora_path=lora.path, lora_name=lora.name
|
||||
)
|
||||
load_result = await self.load_lora_adapter(
|
||||
request=load_request, base_model_name=lora.base_model_name
|
||||
)
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.error.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: LoRARequest | None = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora_request: LoRARequest that contains a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all adapters."""
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
model_cards = [
|
||||
ModelCard(
|
||||
id=base_model.name,
|
||||
max_model_len=max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(
|
||||
id=lora.lora_name,
|
||||
root=lora.path,
|
||||
parent=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
else self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
|
||||
) -> ErrorResponse | str:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_path = request.lora_path
|
||||
lora_int_id = (
|
||||
self.lora_requests[lora_name].lora_int_id
|
||||
if lora_name in self.lora_requests
|
||||
else self.lora_id_counter.inc(1)
|
||||
)
|
||||
lora_request = LoRARequest(
|
||||
lora_name=lora_name,
|
||||
lora_int_id=lora_int_id,
|
||||
lora_path=lora_path,
|
||||
load_inplace=request.load_inplace,
|
||||
)
|
||||
if base_model_name is not None and self.is_base_model(base_model_name):
|
||||
lora_request.base_model_name = base_model_name
|
||||
|
||||
# Validate that the adapter can be loaded into the engine
|
||||
# This will also preload it for incoming requests
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
except Exception as e:
|
||||
error_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "No adapter found" in str(e):
|
||||
error_type = "NotFoundError"
|
||||
status_code = HTTPStatus.NOT_FOUND
|
||||
|
||||
return create_error_response(
|
||||
message=str(e), err_type=error_type, status_code=status_code
|
||||
)
|
||||
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
|
||||
)
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self, request: UnloadLoRAAdapterRequest
|
||||
) -> ErrorResponse | str:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# Safe to delete now since we hold the lock
|
||||
del self.lora_requests[lora_name]
|
||||
logger.info("Removed LoRA adapter: name '%s'", lora_name)
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoRAAdapterRequest
|
||||
) -> ErrorResponse | None:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
# If not loading inplace
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if not request.load_inplace and request.lora_name in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=f"The lora adapter '{request.lora_name}' has already been "
|
||||
"loaded. If you want to load the adapter in place, set 'load_inplace'"
|
||||
" to True.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self, request: UnloadLoRAAdapterRequest
|
||||
) -> ErrorResponse | None:
|
||||
# Check if 'lora_name' is not provided return an error
|
||||
if not request.lora_name:
|
||||
return create_error_response(
|
||||
message="'lora_name' needs to be provided to unload a LoRA adapter.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if request.lora_name not in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
if lora_name in self.lora_requests:
|
||||
return self.lora_requests[lora_name]
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name,
|
||||
resolver.__class__.__name__,
|
||||
)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.",
|
||||
lora_name,
|
||||
resolver.__class__.__name__,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(
|
||||
f"LoRA adapter '{lora_name}' was found but could not be loaded."
|
||||
),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> ErrorResponse:
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(message),
|
||||
type=err_type,
|
||||
code=status_code.value,
|
||||
)
|
||||
)
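# Sketch of the helper above (hypothetical message text): for example,
# create_error_response(
#     "model 'foo' was not found",
#     err_type="NotFoundError",
#     status_code=HTTPStatus.NOT_FOUND,
# )
# produces an ErrorResponse that serializes to
# {"error": {"message": "model 'foo' was not found", "type": "NotFoundError",
#  "param": null, "code": 404}}.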
120
vllm/entrypoints/openai/orca_metrics.py
Normal file
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utility functions that create ORCA endpoint load report response headers.
"""

import json
from collections.abc import Mapping

from vllm.logger import init_logger
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot

logger = init_logger(__name__)


def create_orca_header(
    metrics_format: str, named_metrics: list[tuple[str, float]]
) -> Mapping[str, str] | None:
    """
    Creates an ORCA header named 'endpoint-load-metrics' in the specified
    format, populated with the custom metrics passed in named_metrics.
    ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
    ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto

    Parameters:
    - metrics_format (str): The format of the header ('TEXT', 'JSON').
    - named_metrics (list[tuple[str, float]]): List of tuples with metric names
      and their corresponding double values.

    Returns:
    - Optional[Mapping[str, str]]: A dictionary with 'endpoint-load-metrics' as
      the header key and the ORCA header string (format prefix followed by the
      named_metrics data) as the value.
    """

    if metrics_format.lower() not in ["text", "json"]:
        logger.warning(
            "Warning: `%s` format is not supported in the ORCA response header",
            metrics_format,
        )
        return None

    header = {}
    orca_report = {
        "named_metrics": {
            metric_name: value
            for metric_name, value in named_metrics
            if isinstance(metric_name, str) and isinstance(value, float)
        }
    }
    # output example:
    # endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
    if metrics_format.lower() == "text":
        native_http_header = ", ".join(
            [
                f"named_metrics.{metric_name}={value}"
                for metric_name, value in named_metrics
                if isinstance(metric_name, str) and isinstance(value, float)
            ]
        )
        header["endpoint-load-metrics"] = f"TEXT {native_http_header}"

    # output example:
    # endpoint-load-metrics: JSON "named_metrics": {"custom-metric-util": 0.4}
    elif metrics_format.lower() == "json":
        header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"

    logger.info("Created ORCA header %s", header)

    return header


def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
    """
    Collects current metrics from Prometheus and returns some of them
    in the form of the `named_metrics` list for `create_orca_header()`.

    Parameters:
    - None

    Returns:
    - list[tuple[str, float]]: List of tuples of metric names and their values.
    """
    named_metrics: list[tuple[str, float]] = []
    # Map from prometheus metric names to ORCA named metrics.
    prometheus_to_orca_metrics = {
        "vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
        "vllm:num_requests_waiting": "num_requests_waiting",
    }
    metrics = get_metrics_snapshot()
    for metric in metrics:
        orca_name = prometheus_to_orca_metrics.get(metric.name)
        # If this metric is mapped into ORCA, then add it to the report.
        # Note: Only Gauge metrics are currently supported.
        if orca_name is not None and isinstance(metric, Gauge):
            named_metrics.append((str(orca_name), float(metric.value)))
    return named_metrics


def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
    """
    Creates ORCA headers named 'endpoint-load-metrics' in the specified format.
    Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`.

    ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
    ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto

    Parameters:
    - metrics_format (str): The format of the header ('TEXT', 'JSON').

    Returns:
    - Optional[Mapping[str, str]]: A dictionary with 'endpoint-load-metrics' as
      the header key and the ORCA header string (format prefix followed by the
      named_metrics data) as the value.
    """
    if not metrics_format:
        return None
    # Get named metrics from prometheus.
    named_metrics = get_named_metrics_from_prometheus()
    return create_orca_header(metrics_format, named_metrics)
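# Hedged usage sketch (hypothetical metric values) for the helpers above;
# _example_orca_header is an illustrative name, not part of the module API.
def _example_orca_header() -> Mapping[str, str] | None:
    # With the TEXT format this returns something like
    # {"endpoint-load-metrics": "TEXT named_metrics.kv_cache_usage_perc=0.4, "
    #                           "named_metrics.num_requests_waiting=2.0"}
    return create_orca_header(
        "TEXT",
        [("kv_cache_usage_perc", 0.4), ("num_requests_waiting", 2.0)],
    )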
0
vllm/entrypoints/openai/parser/__init__.py
Normal file
394
vllm/entrypoints/openai/parser/harmony_utils.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import datetime
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (
|
||||
Author,
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncodingName,
|
||||
Message,
|
||||
ReasoningEffort,
|
||||
Role,
|
||||
StreamableParser,
|
||||
SystemContent,
|
||||
TextContent,
|
||||
ToolDescription,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionToolsParam
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
"low": ReasoningEffort.LOW,
|
||||
}
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
# Builtin tools that should be included in the system message when
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
BUILTIN_TOOL_TO_MCP_SERVER_LABEL: dict[str, str] = {
|
||||
"python": "code_interpreter",
|
||||
"browser": "web_search_preview",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
# Derive MCP_BUILTIN_TOOLS from the canonical mapping
|
||||
MCP_BUILTIN_TOOLS: set[str] = set(BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: set[str]) -> bool:
|
||||
"""
|
||||
Checks if the given tool types are custom tools
|
||||
(i.e. any tool other than MCP builtin tools)
|
||||
"""
|
||||
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
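# Quick sketch of the check above ("my_custom_tool" is a hypothetical name):
# built-in MCP server labels alone do not count as custom tools, anything
# else does.
#
# has_custom_tools({"web_search_preview", "container"})        -> False
# has_custom_tools({"web_search_preview", "my_custom_tool"})   -> True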
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _harmony_encoding
|
||||
if _harmony_encoding is None:
|
||||
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return _harmony_encoding
|
||||
|
||||
|
||||
def get_system_message(
|
||||
model_identity: str | None = None,
|
||||
reasoning_effort: Literal["high", "medium", "low"] | None = None,
|
||||
start_date: str | None = None,
|
||||
browser_description: str | None = None,
|
||||
python_description: str | None = None,
|
||||
container_description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
with_custom_tools: bool = False,
|
||||
) -> Message:
|
||||
sys_msg_content = SystemContent.new()
|
||||
if model_identity is not None:
|
||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||
if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
|
||||
current_identity = sys_msg_content.model_identity
|
||||
new_identity = (
|
||||
f"{current_identity}\n{instructions}" if current_identity else instructions
|
||||
)
|
||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
||||
if reasoning_effort is not None:
|
||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||
REASONING_EFFORT[reasoning_effort]
|
||||
)
|
||||
if start_date is None:
|
||||
# NOTE(woosuk): This brings non-determinism in vLLM.
|
||||
# Set VLLM_SYSTEM_START_DATE to pin it.
|
||||
start_date = envs.VLLM_SYSTEM_START_DATE or datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
|
||||
if browser_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
if python_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||
if container_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||
return sys_msg
|
||||
|
||||
|
||||
def create_tool_definition(tool: ChatCompletionToolsParam | Tool):
|
||||
if isinstance(tool, ChatCompletionToolsParam):
|
||||
return ToolDescription.new(
|
||||
name=tool.function.name,
|
||||
description=tool.function.description,
|
||||
parameters=tool.function.parameters,
|
||||
)
|
||||
return ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
)
|
||||
|
||||
|
||||
def get_developer_message(
|
||||
instructions: str | None = None,
|
||||
tools: list[Tool | ChatCompletionToolsParam] | None = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools: list[Tool | ChatCompletionToolsParam] = []
|
||||
for tool in tools:
|
||||
if tool.type in (
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
):
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
function_tools.append(tool)
|
||||
else:
|
||||
raise ValueError(f"tool type {tool.type} not supported")
|
||||
if function_tools:
|
||||
function_tool_descriptions = [
|
||||
create_tool_definition(tool) for tool in function_tools
|
||||
]
|
||||
dev_msg_content = dev_msg_content.with_function_tools(
|
||||
function_tool_descriptions
|
||||
)
|
||||
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
|
||||
return dev_msg
|
||||
|
||||
|
||||
def get_user_message(content: str) -> Message:
|
||||
return Message.from_role_and_content(Role.USER, content)
|
||||
|
||||
|
||||
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
|
||||
"""
|
||||
Parse a list of messages from request.messages in the Chat Completion API to
|
||||
Harmony messages.
|
||||
"""
|
||||
msgs: list[Message] = []
|
||||
tool_id_names: dict[str, str] = {}
|
||||
|
||||
# Collect tool id to name mappings for tool response recipient values
|
||||
for chat_msg in chat_msgs:
|
||||
for tool_call in chat_msg.get("tool_calls", []):
|
||||
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
|
||||
"name"
|
||||
)
|
||||
|
||||
for chat_msg in chat_msgs:
|
||||
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
|
||||
|
||||
msgs = auto_drop_analysis_messages(msgs)
|
||||
return msgs
|
||||
|
||||
|
||||
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
|
||||
"""
|
||||
Harmony models expect the analysis messages (representing raw chain of thought) to
|
||||
be dropped after an assistant message to the final channel is produced from the
|
||||
reasoning of those messages.
|
||||
|
||||
The openai-harmony library does this if the very last assistant message is to the
|
||||
final channel, but it does not handle the case where we're in longer multi-turn
|
||||
conversations and the client gave us reasoning content from previous turns of
|
||||
the conversation with multiple assistant messages to the final channel in the
|
||||
conversation.
|
||||
|
||||
So, we find the index of the last assistant message to the final channel and drop
|
||||
all analysis messages that precede it, leaving only the analysis messages that
|
||||
are relevant to the current part of the conversation.
|
||||
"""
|
||||
last_assistant_final_index = -1
|
||||
for i in range(len(msgs) - 1, -1, -1):
|
||||
msg = msgs[i]
|
||||
if msg.author.role == "assistant" and msg.channel == "final":
|
||||
last_assistant_final_index = i
|
||||
break
|
||||
|
||||
cleaned_msgs: list[Message] = []
|
||||
for i, msg in enumerate(msgs):
|
||||
if i < last_assistant_final_index and msg.channel == "analysis":
|
||||
continue
|
||||
cleaned_msgs.append(msg)
|
||||
|
||||
return cleaned_msgs
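# Sketch of the rule above: for assistant messages on channels
#   [analysis, final, analysis]
# the last "final" message is at index 1, so the analysis message at index 0
# is dropped, while the trailing analysis at index 2 (reasoning for the
# current turn) is kept.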
|
||||
|
||||
|
||||
def flatten_chat_text_content(content: str | list | None) -> str | None:
|
||||
"""
|
||||
Extract the text parts from a chat message content field and flatten them
|
||||
into a single string.
|
||||
"""
|
||||
if isinstance(content, list):
|
||||
return "".join(
|
||||
item.get("text", "")
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
)
|
||||
return content
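# Quick sketch of the helper above:
#   flatten_chat_text_content(
#       [{"type": "text", "text": "Hello "}, {"type": "text", "text": "world"}]
#   ) == "Hello world"
# while plain strings (and None) are passed through unchanged.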
|
||||
|
||||
|
||||
def parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names: dict[str, str] | None = None
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Parse a message from request.messages in the Chat Completion API to
|
||||
Harmony messages.
|
||||
"""
|
||||
tool_id_names = tool_id_names or {}
|
||||
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
role = chat_msg.get("role")
|
||||
msgs: list[Message] = []
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls", [])
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
content = flatten_chat_text_content(chat_msg.get("content"))
|
||||
if content:
|
||||
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
|
||||
commentary_msg = commentary_msg.with_channel("commentary")
|
||||
msgs.append(commentary_msg)
|
||||
|
||||
reasoning = chat_msg.get("reasoning")
|
||||
if reasoning:
|
||||
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning)
|
||||
analysis_msg = analysis_msg.with_channel("analysis")
|
||||
msgs.append(analysis_msg)
|
||||
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
# Officially, this should be `<|constrain|>json`, but there is no clear
# evidence that it improves accuracy over `json`, and some anecdotes to the
|
||||
# contrary. Further testing of the different content_types is needed.
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
tool_call_id = chat_msg.get("tool_call_id", "")
|
||||
name = tool_id_names.get(tool_call_id, "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
content = flatten_chat_text_content(content)
|
||||
|
||||
msg = (
|
||||
Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"), content
|
||||
)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
return [msg]
|
||||
|
||||
# Non-tool reasoning content
|
||||
reasoning = chat_msg.get("reasoning")
|
||||
if role == "assistant" and reasoning:
|
||||
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning)
|
||||
analysis_msg = analysis_msg.with_channel("analysis")
|
||||
msgs.append(analysis_msg)
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
content = chat_msg.get("content") or ""
|
||||
if content is None:
|
||||
content = ""
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
|
||||
# Only add assistant messages if they have content, as reasoning or tool calling
|
||||
# assistant messages were already added above.
|
||||
if role == "assistant" and contents and contents[0].text:
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
# Send non-tool assistant messages to the final channel
|
||||
msg = msg.with_channel("final")
|
||||
msgs.append(msg)
|
||||
# For user/system/developer messages, add them directly even if no content.
|
||||
elif role != "assistant":
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
msgs.append(msg)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def render_for_completion(messages: list[Message]) -> list[int]:
|
||||
conversation = Conversation.from_messages(messages)
|
||||
token_ids = get_encoding().render_conversation_for_completion(
|
||||
conversation, Role.ASSISTANT
|
||||
)
|
||||
return token_ids
|
||||
|
||||
|
||||
def get_stop_tokens_for_assistant_actions() -> list[int]:
|
||||
return get_encoding().stop_tokens_for_assistant_actions()
|
||||
|
||||
|
||||
def get_streamable_parser_for_assistant() -> StreamableParser:
|
||||
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
|
||||
|
||||
|
||||
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
|
||||
parser = get_streamable_parser_for_assistant()
|
||||
for token_id in token_ids:
|
||||
parser.process(token_id)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_chat_output(
|
||||
token_ids: Sequence[int],
|
||||
) -> tuple[str | None, str | None, bool]:
|
||||
"""
|
||||
Parse the output of a Harmony chat completion into reasoning and final content.
|
||||
Note that when the `openai` tool parser is used, serving_chat only uses this
|
||||
for the reasoning content and gets the final content from the tool call parser.
|
||||
|
||||
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
|
||||
in use, this needs to return the final content without any tool calls parsed.
|
||||
|
||||
Empty reasoning or final content is returned as None instead of an empty string.
|
||||
"""
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
output_msgs = parser.messages
|
||||
is_tool_call = False # TODO: update this when tool call is supported
|
||||
|
||||
# Get completed messages from the parser
|
||||
# - analysis channel: hidden reasoning
|
||||
# - commentary channel without recipient (preambles): visible to user
|
||||
# - final channel: visible to user
|
||||
# - commentary with recipient (tool calls): handled separately by tool parser
|
||||
reasoning_texts = [
|
||||
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
|
||||
]
|
||||
final_texts = [
|
||||
msg.content[0].text
|
||||
for msg in output_msgs
|
||||
if msg.channel == "final" or (msg.channel == "commentary" and not msg.recipient)
|
||||
]
|
||||
|
||||
# Extract partial messages from the parser
|
||||
if parser.current_channel == "analysis" and parser.current_content:
|
||||
reasoning_texts.append(parser.current_content)
|
||||
elif parser.current_channel == "final" and parser.current_content:
|
||||
final_texts.append(parser.current_content)
|
||||
elif (
|
||||
parser.current_channel == "commentary"
|
||||
and not parser.current_recipient
|
||||
and parser.current_content
|
||||
):
|
||||
# Preambles (commentary without recipient) are visible to user
|
||||
final_texts.append(parser.current_content)
|
||||
|
||||
# Flatten multiple messages into a single string
|
||||
reasoning: str | None = "\n".join(reasoning_texts)
|
||||
final_content: str | None = "\n".join(final_texts)
|
||||
|
||||
# Return None instead of empty string since existing callers check for None
|
||||
reasoning = reasoning or None
|
||||
final_content = final_content or None
|
||||
|
||||
return reasoning, final_content, is_tool_call
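# A minimal usage sketch (hypothetical caller; `final_output` is a finished
# RequestOutput from the engine, not a name defined in this module):
#
#     reasoning, content, _ = parse_chat_output(final_output.outputs[0].token_ids)
#     # `reasoning` holds analysis-channel text, `content` the user-visible text;
#     # either may be None when the corresponding channel produced nothing.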
|
||||
179
vllm/entrypoints/openai/parser/responses_parser.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
)
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||
from openai.types.responses.response_output_text import ResponseOutputText
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponsesParser:
|
||||
"""Incremental parser over completion tokens with reasoning support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tokenizer: TokenizerLike,
|
||||
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
request: ResponsesRequest,
|
||||
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||
):
|
||||
self.response_messages: list[ResponseInputOutputItem] = (
|
||||
# TODO: initial messages may not be properly typed
|
||||
response_messages
|
||||
)
|
||||
self.num_init_messages = len(response_messages)
|
||||
self.tokenizer = tokenizer
|
||||
self.request = request
|
||||
|
||||
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
|
||||
self.tool_parser_instance = None
|
||||
if tool_parser_cls is not None:
|
||||
self.tool_parser_instance = tool_parser_cls(tokenizer)
|
||||
|
||||
# Store the last finish_reason to determine response status
|
||||
self.finish_reason: str | None = None
|
||||
|
||||
def process(self, output: CompletionOutput) -> "ResponsesParser":
|
||||
# Store the finish_reason from the output
|
||||
self.finish_reason = output.finish_reason
|
||||
|
||||
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
|
||||
output.text, request=self.request
|
||||
)
|
||||
if reasoning_content:
|
||||
self.response_messages.append(
|
||||
ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
content=[
|
||||
Content(
|
||||
type="reasoning_text",
|
||||
text=reasoning_content,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
function_calls: list[ResponseFunctionToolCall] = []
|
||||
if self.tool_parser_instance is not None:
|
||||
tool_call_info = self.tool_parser_instance.extract_tool_calls(
|
||||
content if content is not None else "",
|
||||
request=self.request, # type: ignore
|
||||
)
|
||||
if tool_call_info is not None and tool_call_info.tools_called:
|
||||
# extract_tool_calls() returns a list of tool calls.
|
||||
function_calls.extend(
|
||||
ResponseFunctionToolCall(
|
||||
id=f"fc_{random_uuid()}",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
for tool_call in tool_call_info.tool_calls
|
||||
)
|
||||
content = tool_call_info.content
|
||||
if content and content.strip() == "":
|
||||
content = None
|
||||
|
||||
if content:
|
||||
self.response_messages.append(
|
||||
ResponseOutputMessage(
|
||||
type="message",
|
||||
id=f"msg_{random_uuid()}",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
text=content,
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
if len(function_calls) > 0:
|
||||
self.response_messages.extend(function_calls)
|
||||
|
||||
return self
|
||||
|
||||
def make_response_output_items_from_parsable_context(
|
||||
self,
|
||||
) -> list[ResponseOutputItem]:
|
||||
"""Given a list of sentences, construct ResponseOutput Items."""
|
||||
response_messages = self.response_messages[self.num_init_messages :]
|
||||
output_messages: list[ResponseOutputItem] = []
|
||||
for message in response_messages:
|
||||
if not isinstance(message, ResponseFunctionToolCallOutputItem):
|
||||
output_messages.append(message)
|
||||
else:
|
||||
if len(output_messages) == 0:
|
||||
raise ValueError(
|
||||
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
|
||||
)
|
||||
if isinstance(output_messages[-1], ResponseFunctionToolCall):
|
||||
mcp_message = McpCall(
|
||||
id=f"{MCP_PREFIX}{random_uuid()}",
|
||||
arguments=output_messages[-1].arguments,
|
||||
name=output_messages[-1].name,
|
||||
server_label=output_messages[
|
||||
-1
|
||||
].name, # TODO: store the server label
|
||||
type="mcp_call",
|
||||
status="completed",
|
||||
output=message.output,
|
||||
# TODO: support error output
|
||||
)
|
||||
output_messages[-1] = mcp_message
|
||||
|
||||
return output_messages
|
||||
|
||||
|
||||
def get_responses_parser_for_simple_context(
|
||||
*,
|
||||
tokenizer: TokenizerLike,
|
||||
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
request: ResponsesRequest,
|
||||
tool_parser_cls,
|
||||
) -> ResponsesParser:
|
||||
"""Factory function to create a ResponsesParser with
|
||||
optional reasoning parser.
|
||||
|
||||
Returns:
|
||||
ResponsesParser instance configured with the provided parser
|
||||
"""
|
||||
return ResponsesParser(
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=reasoning_parser_cls,
|
||||
response_messages=response_messages,
|
||||
request=request,
|
||||
tool_parser_cls=tool_parser_cls,
|
||||
)
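# A minimal flow sketch (hypothetical objects: `tokenizer`, `reasoning_parser_cls`,
# `request`, and `output` come from the serving layer and are assumptions here):
#
#     parser = get_responses_parser_for_simple_context(
#         tokenizer=tokenizer,
#         reasoning_parser_cls=reasoning_parser_cls,
#         response_messages=[],
#         request=request,
#         tool_parser_cls=None,
#     )
#     parser.process(output)  # output: a finished CompletionOutput
#     items = parser.make_response_output_items_from_parsable_context()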
|
||||
2
vllm/entrypoints/openai/realtime/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
75
vllm/entrypoints/openai/realtime/api_router.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, FastAPI, WebSocket
|
||||
|
||||
from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
|
||||
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/v1/realtime")
|
||||
async def realtime_endpoint(websocket: WebSocket):
|
||||
"""WebSocket endpoint for realtime audio transcription.
|
||||
|
||||
Protocol:
|
||||
1. Client connects to ws://host/v1/realtime
|
||||
2. Server sends session.created event
|
||||
3. Client optionally sends session.update with model/params
|
||||
4. Client sends input_audio_buffer.commit when ready
|
||||
5. Client sends input_audio_buffer.append events with base64 PCM16 chunks
|
||||
6. Server processes and sends transcription.delta events
|
||||
7. Server sends transcription.done with final text + usage
|
||||
8. Repeat from step 5 for next utterance
|
||||
9. Optionally, client sends input_audio_buffer.commit with final=True
|
||||
to signal audio input is finished. Useful when streaming audio files.
|
||||
|
||||
Audio format: PCM16, 16kHz, mono, base64-encoded
|
||||
"""
|
||||
app = websocket.app
|
||||
serving = app.state.openai_serving_realtime
|
||||
|
||||
connection = RealtimeConnection(websocket, serving)
|
||||
await connection.handle_connection()
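# A hypothetical client sketch (illustration only, not part of this module).
# It assumes the third-party `websockets` package and a server on
# localhost:8000; the model name and audio contents are placeholders. The
# message order follows the protocol documented in `realtime_endpoint` above.
async def _example_realtime_client() -> None:
    import base64
    import json

    import numpy as np
    import websockets  # assumed dependency for this sketch

    async with websockets.connect("ws://localhost:8000/v1/realtime") as ws:
        print(await ws.recv())  # session.created
        await ws.send(json.dumps({"type": "session.update", "model": "my-asr-model"}))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
        # One second of silence as PCM16 @ 16kHz, base64-encoded.
        pcm16 = np.zeros(16000, dtype=np.int16).tobytes()
        await ws.send(
            json.dumps(
                {
                    "type": "input_audio_buffer.append",
                    "audio": base64.b64encode(pcm16).decode(),
                }
            )
        )
        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
        while True:
            event = json.loads(await ws.recv())
            if event["type"] == "transcription.done":
                print(event["text"])
                break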
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
"""Attach the realtime router to the FastAPI app."""
|
||||
app.include_router(router)
|
||||
logger.info("Realtime API router attached")
|
||||
|
||||
|
||||
def init_realtime_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
state.openai_serving_realtime = (
|
||||
OpenAIServingRealtime(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "realtime" in supported_tasks
|
||||
else None
|
||||
)
|
||||
279
vllm/entrypoints/openai/realtime/connection.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.realtime.protocol import (
|
||||
ErrorEvent,
|
||||
InputAudioBufferAppend,
|
||||
InputAudioBufferCommit,
|
||||
SessionCreated,
|
||||
TranscriptionDelta,
|
||||
TranscriptionDone,
|
||||
)
|
||||
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RealtimeConnection:
|
||||
"""Manages WebSocket lifecycle and state for realtime transcription.
|
||||
|
||||
This class handles:
|
||||
- WebSocket connection lifecycle (accept, receive, send, close)
|
||||
- Event routing (session.update, append, commit)
|
||||
- Audio buffering via asyncio.Queue
|
||||
- Generation task management
|
||||
- Error handling and cleanup
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket, serving: OpenAIServingRealtime):
|
||||
self.websocket = websocket
|
||||
self.connection_id = f"ws-{uuid4()}"
|
||||
self.serving = serving
|
||||
self.audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
|
||||
self.generation_task: asyncio.Task | None = None
|
||||
|
||||
self._is_connected = False
|
||||
self._is_model_validated = False
|
||||
|
||||
self._max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
|
||||
async def handle_connection(self):
|
||||
"""Main connection loop."""
|
||||
await self.websocket.accept()
|
||||
logger.debug("WebSocket connection accepted: %s", self.connection_id)
|
||||
self._is_connected = True
|
||||
|
||||
# Send session created event
|
||||
await self.send(SessionCreated())
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await self.websocket.receive_text()
|
||||
try:
|
||||
event = json.loads(message)
|
||||
await self.handle_event(event)
|
||||
except json.JSONDecodeError:
|
||||
await self.send_error("Invalid JSON", "invalid_json")
|
||||
except Exception as e:
|
||||
logger.exception("Error handling event: %s", e)
|
||||
await self.send_error(str(e), "processing_error")
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket disconnected: %s", self.connection_id)
|
||||
self._is_connected = False
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error in connection: %s", e)
|
||||
finally:
|
||||
await self.cleanup()
|
||||
|
||||
def _check_model(self, model: str | None) -> None | ErrorResponse:
|
||||
if self.serving._is_model_supported(model):
|
||||
return None
|
||||
|
||||
return self.serving.create_error_response(
|
||||
message=f"The model `{model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
param="model",
|
||||
)
|
||||
|
||||
async def handle_event(self, event: dict):
|
||||
"""Route events to handlers.
|
||||
|
||||
Supported event types:
|
||||
- session.update: Configure model
|
||||
- input_audio_buffer.append: Add audio chunk to queue
|
||||
- input_audio_buffer.commit: Start transcription generation
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
if event_type == "session.update":
|
||||
logger.debug("Session updated: %s", event)
|
||||
model = event.get("model")
if self._check_model(model) is not None:
    await self.send_error(f"The model `{model}` does not exist.", "model_not_found")
    return
|
||||
self._is_model_validated = True
|
||||
elif event_type == "input_audio_buffer.append":
|
||||
append_event = InputAudioBufferAppend(**event)
|
||||
try:
|
||||
audio_bytes = base64.b64decode(append_event.audio)
|
||||
# Convert PCM16 bytes to float32 numpy array
|
||||
audio_array = (
|
||||
np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
|
||||
/ 32768.0
|
||||
)
|
||||
|
||||
if len(audio_array) / 1024**2 > self._max_audio_filesize_mb:
|
||||
raise VLLMValidationError(
|
||||
"Maximum file size exceeded",
|
||||
parameter="audio_filesize_mb",
|
||||
value=len(audio_array) / 1024**2,
|
||||
)
|
||||
if len(audio_array) == 0:
|
||||
raise VLLMValidationError("Can't process empty audio.")
|
||||
|
||||
# Put audio chunk in queue
|
||||
self.audio_queue.put_nowait(audio_array)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to decode audio: %s", e)
|
||||
await self.send_error("Invalid audio data", "invalid_audio")
|
||||
|
||||
elif event_type == "input_audio_buffer.commit":
|
||||
if not self._is_model_validated:
|
||||
err_msg = (
|
||||
"Model not validated. Make sure to validate the"
|
||||
" model by sending a session.update event."
|
||||
)
|
||||
await self.send_error(
|
||||
err_msg,
|
||||
"model_not_validated",
|
||||
)
return
|
||||
|
||||
commit_event = InputAudioBufferCommit(**event)
|
||||
# final signals that the audio is finished
|
||||
if commit_event.final:
|
||||
self.audio_queue.put_nowait(None)
|
||||
else:
|
||||
await self.start_generation()
|
||||
else:
|
||||
await self.send_error(f"Unknown event type: {event_type}", "unknown_event")
|
||||
|
||||
async def audio_stream_generator(self) -> AsyncGenerator[np.ndarray, None]:
|
||||
"""Generator that yields audio chunks from the queue."""
|
||||
while True:
|
||||
audio_chunk = await self.audio_queue.get()
|
||||
if audio_chunk is None: # Sentinel value to stop
|
||||
break
|
||||
yield audio_chunk
|
||||
|
||||
async def start_generation(self):
|
||||
"""Start the transcription generation task."""
|
||||
if self.generation_task is not None and not self.generation_task.done():
|
||||
logger.warning("Generation already in progress, ignoring commit")
|
||||
return
|
||||
|
||||
# Create audio stream generator
|
||||
audio_stream = self.audio_stream_generator()
|
||||
input_stream = asyncio.Queue[list[int]]()
|
||||
|
||||
# Transform to StreamingInput generator
|
||||
streaming_input_gen = self.serving.transcribe_realtime(
|
||||
audio_stream, input_stream
|
||||
)
|
||||
|
||||
# Start generation task
|
||||
self.generation_task = asyncio.create_task(
|
||||
self._run_generation(streaming_input_gen, input_stream)
|
||||
)
|
||||
|
||||
async def _run_generation(
|
||||
self,
|
||||
streaming_input_gen: AsyncGenerator,
|
||||
input_stream: asyncio.Queue[list[int]],
|
||||
):
|
||||
"""Run the generation and stream results back to the client.
|
||||
|
||||
This method:
|
||||
1. Creates sampling parameters from session config
|
||||
2. Passes the streaming input generator to engine.generate()
|
||||
3. Streams transcription.delta events as text is generated
|
||||
4. Sends final transcription.done event with usage stats
|
||||
5. Feeds generated token IDs back to input_stream for next iteration
|
||||
6. Cleans up the audio queue
|
||||
"""
|
||||
request_id = f"rt-{self.connection_id}-{uuid4()}"
|
||||
full_text = ""
|
||||
|
||||
prompt_token_ids_len: int = 0
|
||||
completion_tokens_len: int = 0
|
||||
|
||||
try:
|
||||
# Create sampling params
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams.from_optional(
|
||||
temperature=0.0,
|
||||
max_tokens=self.serving.model_cls.realtime_max_tokens,
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
skip_clone=True,
|
||||
)
|
||||
|
||||
# Pass the streaming input generator to the engine
|
||||
# The engine will consume audio chunks as they arrive and
|
||||
# stream back transcription results incrementally
|
||||
result_gen = self.serving.engine_client.generate(
|
||||
prompt=streaming_input_gen,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
# Stream results back to client as they're generated
|
||||
async for output in result_gen:
|
||||
if output.outputs and len(output.outputs) > 0:
|
||||
if not prompt_token_ids_len and output.prompt_token_ids:
|
||||
prompt_token_ids_len = len(output.prompt_token_ids)
|
||||
|
||||
delta = output.outputs[0].text
|
||||
full_text += delta
|
||||
|
||||
# append output to input
|
||||
input_stream.put_nowait(list(output.outputs[0].token_ids))
|
||||
await self.send(TranscriptionDelta(delta=delta))
|
||||
|
||||
completion_tokens_len += len(output.outputs[0].token_ids)
|
||||
|
||||
if not self._is_connected:
|
||||
# finish because websocket connection was killed
|
||||
break
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_token_ids_len,
|
||||
completion_tokens=completion_tokens_len,
|
||||
total_tokens=prompt_token_ids_len + completion_tokens_len,
|
||||
)
|
||||
|
||||
# Send final completion event
|
||||
await self.send(TranscriptionDone(text=full_text, usage=usage))
|
||||
|
||||
# Clear queue for next utterance
|
||||
while not self.audio_queue.empty():
|
||||
self.audio_queue.get_nowait()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in generation: %s", e)
|
||||
await self.send_error(str(e), "processing_error")
|
||||
|
||||
async def send(
|
||||
self, event: SessionCreated | TranscriptionDelta | TranscriptionDone
|
||||
):
|
||||
"""Send event to client."""
|
||||
data = event.model_dump_json()
|
||||
await self.websocket.send_text(data)
|
||||
|
||||
async def send_error(self, message: str, code: str | None = None):
|
||||
"""Send error event to client."""
|
||||
error_event = ErrorEvent(error=message, code=code)
|
||||
await self.websocket.send_text(error_event.model_dump_json())
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
# Signal audio stream to stop
|
||||
self.audio_queue.put_nowait(None)
|
||||
|
||||
# Cancel generation task if running
|
||||
if self.generation_task and not self.generation_task.done():
|
||||
self.generation_task.cancel()
|
||||
|
||||
logger.debug("Connection cleanup complete: %s", self.connection_id)
|
||||
68
vllm/entrypoints/openai/realtime/protocol.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
OpenAIBaseModel,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# Client -> Server Events
|
||||
|
||||
|
||||
class InputAudioBufferAppend(OpenAIBaseModel):
|
||||
"""Append audio chunk to buffer"""
|
||||
|
||||
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
|
||||
audio: str # base64-encoded PCM16 @ 16kHz
|
||||
|
||||
|
||||
class InputAudioBufferCommit(OpenAIBaseModel):
|
||||
"""Process accumulated audio buffer"""
|
||||
|
||||
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
|
||||
final: bool = False
|
||||
|
||||
|
||||
|
||||
class SessionUpdate(OpenAIBaseModel):
|
||||
"""Configure session parameters"""
|
||||
|
||||
type: Literal["session.update"] = "session.update"
|
||||
model: str | None = None
|
||||
|
||||
|
||||
# Server -> Client Events
class SessionCreated(OpenAIBaseModel):
|
||||
"""Connection established notification"""
|
||||
|
||||
type: Literal["session.created"] = "session.created"
|
||||
id: str = Field(default_factory=lambda: f"sess-{random_uuid()}")
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
|
||||
class TranscriptionDelta(OpenAIBaseModel):
|
||||
"""Incremental transcription text"""
|
||||
|
||||
type: Literal["transcription.delta"] = "transcription.delta"
|
||||
delta: str # Incremental text
|
||||
|
||||
|
||||
class TranscriptionDone(OpenAIBaseModel):
|
||||
"""Final transcription with usage stats"""
|
||||
|
||||
type: Literal["transcription.done"] = "transcription.done"
|
||||
text: str # Complete transcription
|
||||
usage: UsageInfo | None = None
|
||||
|
||||
|
||||
class ErrorEvent(OpenAIBaseModel):
|
||||
"""Error notification"""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: str
|
||||
code: str | None = None
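# Example wire payloads for the events above (illustrative values only):
#
#   {"type": "session.update", "model": "my-asr-model"}
#   {"type": "input_audio_buffer.append", "audio": "<base64 PCM16 @ 16kHz>"}
#   {"type": "input_audio_buffer.commit", "final": false}
#   {"type": "transcription.delta", "delta": "hello"}
#   {"type": "transcription.done", "text": "hello world", "usage": {"prompt_tokens": 480, ...}}
#   {"type": "error", "error": "Invalid audio data", "code": "invalid_audio"}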
|
||||
90
vllm/entrypoints/openai/realtime/serving.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import cached_property
|
||||
from typing import Literal, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.engine.protocol import EngineClient, StreamingInput
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsRealtime
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingRealtime(OpenAIServing):
|
||||
"""Realtime audio transcription service via WebSocket streaming.
|
||||
|
||||
Provides streaming audio-to-text transcription by transforming audio chunks
|
||||
into StreamingInput objects that can be consumed by the engine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.task_type: Literal["realtime"] = "realtime"
|
||||
|
||||
logger.info("OpenAIServingRealtime initialized for task: %s", self.task_type)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsRealtime]:
|
||||
"""Get the model class that supports transcription."""
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsRealtime], model_cls)
|
||||
|
||||
async def transcribe_realtime(
|
||||
self,
|
||||
audio_stream: AsyncGenerator[np.ndarray, None],
|
||||
input_stream: asyncio.Queue[list[int]],
|
||||
) -> AsyncGenerator[StreamingInput, None]:
|
||||
"""Transform audio stream into StreamingInput for engine.generate().
|
||||
|
||||
Args:
|
||||
audio_stream: Async generator yielding float32 numpy audio arrays
|
||||
input_stream: Queue containing context token IDs from previous
|
||||
generation outputs. Used for autoregressive multi-turn
|
||||
processing where each generation's output becomes the context
|
||||
for the next iteration.
|
||||
|
||||
Yields:
|
||||
StreamingInput objects containing audio prompts for the engine
|
||||
"""
|
||||
model_config = self.model_config
|
||||
renderer = self.renderer
|
||||
|
||||
# mypy is being stupid
|
||||
# TODO(Patrick) - fix this
|
||||
stream_input_iter = cast(
|
||||
AsyncGenerator[PromptType, None],
|
||||
self.model_cls.buffer_realtime_audio(
|
||||
audio_stream, input_stream, model_config
|
||||
),
|
||||
)
|
||||
|
||||
async for prompt in stream_input_iter:
|
||||
parsed_prompt = parse_model_prompt(model_config, prompt)
|
||||
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
|
||||
|
||||
yield StreamingInput(prompt=engine_prompt)
|
||||
2
vllm/entrypoints/openai/responses/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
141
vllm/entrypoints/openai/responses/api_router.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
StreamingResponsesResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def responses(request: Request) -> OpenAIServingResponses | None:
|
||||
return request.app.state.openai_serving_responses
|
||||
|
||||
|
||||
async def _convert_stream_to_sse_events(
|
||||
generator: AsyncGenerator[StreamingResponsesResponse, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Convert the generator to a stream of events in SSE format"""
|
||||
async for event in generator:
|
||||
event_type = getattr(event, "type", "unknown")
|
||||
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
|
||||
event_data = (
|
||||
f"event: {event_type}\ndata: {event.model_dump_json(indent=None)}\n\n"
|
||||
)
|
||||
yield event_data
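# For example, a single event whose `type` is "response.output_text.delta"
# (an illustrative type, not an exhaustive list) is emitted on the wire as:
#
#   event: response.output_text.delta
#   data: {"type": "response.output_text.delta", "delta": "Hello", ...}
#   <blank line terminating the event>
#
# The SSE event name mirrors the `type` field of each StreamingResponsesResponse.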
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/responses",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_responses(request: ResponsesRequest, raw_request: Request):
|
||||
handler = responses(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Responses API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.create_responses(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, ResponsesResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(
|
||||
content=_convert_stream_to_sse_events(generator), media_type="text/event-stream"
|
||||
)
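# A hypothetical request sketch (illustration only, not part of this module).
# It assumes a vLLM server on localhost:8000 and the third-party `requests`
# package; the model name and input are placeholders.
def _example_create_response() -> None:
    import requests  # assumed dependency for this sketch

    resp = requests.post(
        "http://localhost:8000/v1/responses",
        json={"model": "my-model", "input": "What is 2 + 2?"},
        timeout=60,
    )
    resp.raise_for_status()
    # The body follows the Responses API shape, e.g. an `output` list holding
    # reasoning items, output messages, and tool calls.
    print(resp.json()["output"])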
|
||||
|
||||
|
||||
@router.get("/v1/responses/{response_id}")
|
||||
@load_aware_call
|
||||
async def retrieve_responses(
|
||||
response_id: str,
|
||||
raw_request: Request,
|
||||
starting_after: int | None = None,
|
||||
stream: bool | None = False,
|
||||
):
|
||||
handler = responses(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Responses API"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await handler.retrieve_responses(
|
||||
response_id,
|
||||
starting_after=starting_after,
|
||||
stream=stream,
|
||||
)
|
||||
except Exception as e:
|
||||
response = handler.create_error_response(e)
|
||||
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=response.model_dump(), status_code=response.error.code
|
||||
)
|
||||
elif isinstance(response, ResponsesResponse):
|
||||
return JSONResponse(content=response.model_dump())
|
||||
return StreamingResponse(
|
||||
content=_convert_stream_to_sse_events(response), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/responses/{response_id}/cancel")
|
||||
@load_aware_call
|
||||
async def cancel_responses(response_id: str, raw_request: Request):
|
||||
handler = responses(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Responses API"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await handler.cancel_responses(response_id)
|
||||
except Exception as e:
|
||||
response = handler.create_error_response(e)
|
||||
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=response.model_dump(), status_code=response.error.code
|
||||
)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
918
vllm/entrypoints/openai/responses/context.py
Normal file
@@ -0,0 +1,918 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Final, Union
|
||||
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
)
|
||||
from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.mcp.tool import Tool
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
FunctionCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
get_encoding,
|
||||
get_streamable_parser_for_assistant,
|
||||
render_for_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.parser.responses_parser import (
|
||||
get_responses_parser_for_simple_context,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponseRawMessageAndToken,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This is currently needed as the tool type doesn't 1:1 match the
|
||||
# tool namespace, which is what is used to look up the
|
||||
# connection to the tool server
|
||||
_TOOL_NAME_TO_TYPE_MAP = {
|
||||
"browser": "web_search_preview",
|
||||
"python": "code_interpreter",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}"
|
||||
)
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnMetrics:
|
||||
"""Tracks token and toolcall details for a single conversation turn."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cached_input_tokens: int = 0,
|
||||
tool_output_tokens: int = 0,
|
||||
) -> None:
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
self.cached_input_tokens = cached_input_tokens
|
||||
self.tool_output_tokens = tool_output_tokens
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset counters for a new turn."""
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_input_tokens = 0
|
||||
self.tool_output_tokens = 0
|
||||
|
||||
def copy(self) -> "TurnMetrics":
|
||||
"""Create a copy of this turn's token counts."""
|
||||
return TurnMetrics(
|
||||
self.input_tokens,
|
||||
self.output_tokens,
|
||||
self.cached_input_tokens,
|
||||
self.tool_output_tokens,
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
@abstractmethod
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def append_tool_output(self, output) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(self) -> list[Message]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_for_completion(self) -> list[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
def _create_json_parse_error_messages(
|
||||
last_msg: Message, e: json.JSONDecodeError
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Create an error message when JSON parsing fails.
|
||||
"""
|
||||
error_msg = (
|
||||
f"Error parsing tool arguments as JSON: {str(e)}. "
|
||||
"Please ensure the tool call arguments are valid JSON and try again."
|
||||
)
|
||||
content = TextContent(text=error_msg)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
"""This is a context that cannot handle MCP tool calls"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
|
||||
# Accumulated final output for streaming mode
|
||||
self._accumulated_text: str = ""
|
||||
self._accumulated_token_ids: list[int] = []
|
||||
self._accumulated_logprobs: list = []
|
||||
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
# TODO: num_reasoning_tokens is not implemented yet.
|
||||
self.num_reasoning_tokens = 0
|
||||
# not implemented yet for SimpleContext
|
||||
self.all_turn_metrics = []
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
if not isinstance(output, RequestOutput):
|
||||
raise ValueError("SimpleContext only supports RequestOutput.")
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
# Accumulate text, token_ids, and logprobs for streaming mode
|
||||
delta_output = output.outputs[0]
|
||||
self._accumulated_text += delta_output.text
|
||||
self._accumulated_token_ids.extend(delta_output.token_ids)
|
||||
if delta_output.logprobs is not None:
|
||||
self._accumulated_logprobs.extend(delta_output.logprobs)
|
||||
|
||||
if len(self.input_messages) == 0:
|
||||
output_prompt = output.prompt or ""
|
||||
output_prompt_token_ids = output.prompt_token_ids or []
|
||||
self.input_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output_prompt,
|
||||
tokens=output_prompt_token_ids,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def output_messages(self) -> list[ResponseRawMessageAndToken]:
|
||||
"""Return consolidated output as a single message.
|
||||
|
||||
In streaming mode, text and tokens are accumulated across many deltas.
|
||||
This property returns them as a single entry rather than one per delta.
|
||||
"""
|
||||
if not self._accumulated_text and not self._accumulated_token_ids:
|
||||
return []
|
||||
return [
|
||||
ResponseRawMessageAndToken(
|
||||
message=self._accumulated_text,
|
||||
tokens=list(self._accumulated_token_ids),
|
||||
)
|
||||
]
|
||||
|
||||
@property
|
||||
def final_output(self) -> RequestOutput | None:
|
||||
"""Return the final output, with complete text/token_ids/logprobs."""
|
||||
if self.last_output is not None and self.last_output.outputs:
|
||||
assert isinstance(self.last_output, RequestOutput)
|
||||
final_output = copy.copy(self.last_output)
|
||||
# copy inner items to avoid modifying last_output
|
||||
final_output.outputs = [replace(item) for item in self.last_output.outputs]
|
||||
final_output.outputs[0].text = self._accumulated_text
|
||||
final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
|
||||
if self._accumulated_logprobs:
|
||||
final_output.outputs[0].logprobs = self._accumulated_logprobs
|
||||
return final_output
|
||||
return self.last_output
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
return False
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class ParsableContext(ConversationContext):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
tokenizer: TokenizerLike,
|
||||
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
|
||||
request: ResponsesRequest,
|
||||
available_tools: list[str] | None,
|
||||
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
# not implemented yet for ParsableContext
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
|
||||
if reasoning_parser_cls is None:
|
||||
raise ValueError("reasoning_parser_cls must be provided.")
|
||||
|
||||
self.parser = get_responses_parser_for_simple_context(
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=reasoning_parser_cls,
|
||||
response_messages=response_messages,
|
||||
request=request,
|
||||
tool_parser_cls=tool_parser_cls,
|
||||
)
|
||||
self.tool_parser_cls = tool_parser_cls
|
||||
self.request = request
|
||||
|
||||
self.available_tools = available_tools or []
|
||||
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
self._accumulated_token_ids: list[int] = []
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
self.parser.process(output.outputs[0])
|
||||
output_token_ids = output.outputs[0].token_ids or []
|
||||
self._accumulated_token_ids.extend(output_token_ids)
|
||||
|
||||
# only store if enable_response_messages is True, save memory
|
||||
if self.request.enable_response_messages:
|
||||
output_prompt = output.prompt or ""
|
||||
output_prompt_token_ids = output.prompt_token_ids or []
|
||||
if len(self.input_messages) == 0:
|
||||
self.input_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output_prompt,
|
||||
tokens=output_prompt_token_ids,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.output_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output_prompt,
|
||||
tokens=output_prompt_token_ids,
|
||||
)
|
||||
)
|
||||
self.output_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output.outputs[0].text,
|
||||
tokens=output.outputs[0].token_ids,
|
||||
)
|
||||
)
|
||||
|
||||
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
|
||||
self.parser.response_messages.extend(output)
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
"""Return true if the last message is a builtin tool call
|
||||
that the request has enabled."""
|
||||
last_message = self.parser.response_messages[-1]
|
||||
if last_message.type != "function_call":
|
||||
return False
|
||||
if last_message.name in ("code_interpreter", "python"):
|
||||
return "python" in self.available_tools
|
||||
if last_message.name == "web_search_preview":
|
||||
return "browser" in self.available_tools
|
||||
if last_message.name.startswith("container"):
|
||||
return "container" in self.available_tools
|
||||
return False
|
||||
|
||||
async def call_python_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
|
||||
) -> list[ResponseInputOutputItem]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result_parsable_context(self)
|
||||
args = json.loads(last_msg.arguments)
|
||||
param = {
|
||||
"code": args["code"],
|
||||
}
|
||||
result = await tool_session.call_tool("python", param)
|
||||
result_str = result.content[0].text
|
||||
|
||||
message = ResponseFunctionToolCallOutputItem(
|
||||
id=f"mcpo_{random_uuid()}",
|
||||
type="function_call_output",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
output=result_str,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
return [message]
|
||||
|
||||
async def call_search_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
|
||||
) -> list[ResponseInputOutputItem]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result_parsable_context(self)
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.arguments)
|
||||
result = await tool_session.call_tool("search", args)
|
||||
result_str = result.content[0].text
|
||||
|
||||
message = ResponseFunctionToolCallOutputItem(
|
||||
id=f"fco_{random_uuid()}",
|
||||
type="function_call_output",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
output=result_str,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
return [message]
|
||||
|
||||
async def call_container_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result_parsable_context(self)
|
||||
# tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.arguments)
|
||||
result = await tool_session.call_tool("exec", args)
|
||||
result_str = result.content[0].text
|
||||
|
||||
message = ResponseFunctionToolCallOutputItem(
|
||||
id=f"fco_{random_uuid()}",
|
||||
type="function_call_output",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
output=result_str,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
return [message]
|
||||
|
||||
async def call_tool(self) -> list[ResponseInputOutputItem]:
|
||||
if not self.parser.response_messages:
|
||||
return []
|
||||
last_msg = self.parser.response_messages[-1]
|
||||
# change this to a mcp_ function call
|
||||
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
|
||||
self.parser.response_messages[-1] = last_msg
|
||||
if last_msg.name == "code_interpreter":
|
||||
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
|
||||
elif last_msg.name == "web_search_preview":
|
||||
return await self.call_search_tool(self._tool_sessions["browser"], last_msg)
|
||||
elif last_msg.name.startswith("container"):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg
|
||||
)
|
||||
return []
|
||||
|
||||
def render_for_completion(self):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name in self._tool_sessions:
|
||||
continue
|
||||
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = (
|
||||
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
|
||||
)
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id, headers)
|
||||
)
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info(
|
||||
"Cleaning up tool session for %s", tool_session._client_info
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(
|
||||
*(
|
||||
cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: list,
|
||||
available_tools: list[str],
|
||||
):
|
||||
self._messages = messages
|
||||
self.finish_reason: str | None = None
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
self.num_tool_output_tokens = 0
|
||||
|
||||
# Turn tracking - replaces multiple individual tracking variables
|
||||
self.current_turn_metrics = TurnMetrics()
|
||||
# Track metrics for all turns
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
|
||||
def _update_num_reasoning_tokens(self):
|
||||
channel = self.parser.current_channel
|
||||
if channel == "analysis":
|
||||
self.num_reasoning_tokens += 1
|
||||
elif channel == "commentary" and self.parser.current_recipient is not None:
|
||||
# Tool interactions (python/browser/container) are hidden.
|
||||
# Preambles (recipient=None) are visible user text.
|
||||
self.num_reasoning_tokens += 1
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
for token_id in output_token_ids:
|
||||
self.parser.process(token_id)
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self._update_prefill_token_usage(output)
|
||||
self._update_decode_token_usage(output)
|
||||
# Append current turn to all turn list for next turn's calculations
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# append_output is called only once before tool calling
|
||||
# in non-streaming case
|
||||
# so we can append all the parser messages to _messages
|
||||
output_msgs = self.parser.messages
|
||||
# The responses finish reason is set in the last message
|
||||
self.finish_reason = output.outputs[0].finish_reason
|
||||
self._messages.extend(output_msgs)
|
||||
|
||||
def append_tool_output(self, output: list[Message]) -> None:
|
||||
output_msgs = output
|
||||
self._messages.extend(output_msgs)
|
||||
|
||||
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
|
||||
"""Update token usage statistics for the prefill phase of generation.
|
||||
|
||||
The prefill phase processes the input prompt tokens. This method:
|
||||
1. Counts the prompt tokens for this turn
|
||||
2. Calculates tool output tokens for multi-turn conversations
|
||||
3. Updates cached token counts
|
||||
4. Tracks state for next turn calculations
|
||||
|
||||
Tool output tokens are calculated as:
|
||||
current_prompt_tokens - last_turn_prompt_tokens -
|
||||
last_turn_output_tokens
|
||||
This represents tokens added between turns (typically tool responses).
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing prompt token information
|
||||
"""
|
||||
if output.prompt_token_ids is not None:
|
||||
this_turn_input_tokens = len(output.prompt_token_ids)
|
||||
else:
|
||||
this_turn_input_tokens = 0
|
||||
logger.error("RequestOutput appended contains no prompt_token_ids.")
|
||||
|
||||
# Update current turn input tokens
|
||||
self.current_turn_metrics.input_tokens = this_turn_input_tokens
|
||||
self.num_prompt_tokens += this_turn_input_tokens
|
||||
|
||||
# Calculate tool tokens (except on first turn)
|
||||
if self.is_first_turn:
|
||||
self.is_first_turn = False
|
||||
else:
|
||||
previous_turn = self.all_turn_metrics[-1]
|
||||
# start counting tool after first turn
|
||||
# tool tokens = this turn prefill - last turn prefill -
|
||||
# last turn decode
|
||||
this_turn_tool_tokens = (
|
||||
self.current_turn_metrics.input_tokens
|
||||
- previous_turn.input_tokens
|
||||
- previous_turn.output_tokens
|
||||
)
|
||||
|
||||
# Handle negative tool token counts (shouldn't happen in normal
|
||||
# cases)
|
||||
if this_turn_tool_tokens < 0:
|
||||
logger.error(
|
||||
"Negative tool output tokens calculated: %d "
|
||||
"(current_input=%d, previous_input=%d, "
|
||||
"previous_output=%d). Setting to 0.",
|
||||
this_turn_tool_tokens,
|
||||
self.current_turn_metrics.input_tokens,
|
||||
previous_turn.input_tokens,
|
||||
previous_turn.output_tokens,
|
||||
)
|
||||
this_turn_tool_tokens = 0
|
||||
|
||||
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||
self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
|
||||
|
||||
# Update cached tokens
|
||||
num_cached_token = output.num_cached_tokens
|
||||
if num_cached_token is not None:
|
||||
self.num_cached_tokens += num_cached_token
|
||||
self.current_turn_metrics.cached_input_tokens = num_cached_token
|
||||
|
||||
def _update_decode_token_usage(self, output: RequestOutput) -> int:
|
||||
"""Update token usage statistics for the decode phase of generation.
|
||||
|
||||
The decode phase processes the generated output tokens. This method:
|
||||
1. Counts output tokens from all completion outputs
|
||||
2. Updates the total output token count
|
||||
3. Tracks tokens generated in the current turn
|
||||
|
||||
In streaming mode, this is called for each token generated.
|
||||
In non-streaming mode, this is called once with all output tokens.
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing generated token information
|
||||
|
||||
Returns:
|
||||
int: Number of output tokens processed in this call
|
||||
"""
|
||||
updated_output_token_count = 0
|
||||
if output.outputs:
|
||||
for completion_output in output.outputs:
|
||||
# only keep last round
|
||||
updated_output_token_count += len(completion_output.token_ids)
|
||||
self.num_output_tokens += updated_output_token_count
|
||||
self.current_turn_metrics.output_tokens += updated_output_token_count
|
||||
return updated_output_token_count
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
if recipient is None:
|
||||
return False
|
||||
if recipient.startswith("browser."):
|
||||
return "browser" in self.available_tools
|
||||
if recipient.startswith("python"):
|
||||
return "python" in self.available_tools
|
||||
if recipient.startswith("container."):
|
||||
return "container" in self.available_tools
|
||||
return False
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
return []
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
if recipient is not None:
|
||||
if recipient.startswith("browser."):
|
||||
return await self.call_search_tool(
|
||||
self._tool_sessions["browser"], last_msg
|
||||
)
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg
|
||||
)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg
|
||||
)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
return render_for_completion(self.messages)
|
||||
|
||||
async def call_search_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
async def call_python_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
"code": last_msg.content[0].text,
|
||||
}
|
||||
result = await tool_session.call_tool("python", param)
|
||||
result_str = result.content[0].text
|
||||
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name="python")
|
||||
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT,
|
||||
)
|
||||
]
|
||||
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = (
|
||||
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
|
||||
)
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id, headers)
|
||||
)
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Call the container tool. This is expected to run in a stateful Docker
|
||||
container with a command-line terminal.
|
||||
The official container tool expects at least
|
||||
the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
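# Illustrative sketch (not part of the module): the kind of JSON arguments the
# docstring above describes for the container "exec" tool. All values are
# hypothetical.
import json as _json

_example_exec_args = {
    "cmd": ["ls", "-la"],              # command to execute
    "workdir": "/workspace",           # current working directory
    "env": {"PYTHONUNBUFFERED": "1"},  # environment variables
    "session_name": "default",
    "timeout": 30,
    "user": "root",
}
# The model emits this as the message text; the server parses it with json.loads.
assert _json.loads(_json.dumps(_example_exec_args))["cmd"] == ["ls", "-la"]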
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info(
|
||||
"Cleaning up tool session for %s", tool_session._client_info
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(
|
||||
*(
|
||||
cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_output = None
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.encoding = get_encoding()
|
||||
self.last_tok = None
|
||||
self.first_tok_of_message = True
|
||||
self.last_content_delta = None
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
self.last_content_delta = None
|
||||
if self.first_tok_of_message:
|
||||
self._update_prefill_token_usage(output)
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
# beginning of a new message
|
||||
self.first_tok_of_message = output.finished
|
||||
last_delta_text = ""
|
||||
for tok in output.outputs[0].token_ids:
|
||||
self.parser.process(tok)
|
||||
last_delta_text += self.parser.last_content_delta or ""
|
||||
if last_delta_text:
|
||||
self.last_content_delta = last_delta_text
|
||||
self._update_decode_token_usage(output)
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
if len(self._messages) - self.num_init_messages < len(self.parser.messages):
|
||||
self._messages.extend(
|
||||
self.parser.messages[len(self._messages) - self.num_init_messages :]
|
||||
)
|
||||
|
||||
def append_tool_output(self, output: list[Message]) -> None:
|
||||
# Handle the case of tool output in direct message format
|
||||
assert len(output) == 1, "Tool output should be a single message"
|
||||
msg = output[0]
|
||||
# Sometimes the recipient is not set for tool messages,
|
||||
# so we set it to "assistant"
|
||||
if msg.author.role == Role.TOOL and msg.recipient is None:
|
||||
msg.recipient = "assistant"
|
||||
toks = self.encoding.render(msg)
|
||||
for tok in toks:
|
||||
self.parser.process(tok)
|
||||
self.last_tok = toks[-1]
|
||||
# TODO: add tool_output messages to self._messages
|
||||
|
||||
def is_expecting_start(self) -> bool:
|
||||
return self.parser.state == StreamState.EXPECT_START
|
||||
|
||||
def is_assistant_action_turn(self) -> bool:
|
||||
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
# The rendered tokens now include the next turn's starting tokens
|
||||
# (e.g. `<|start|>assistant`),
|
||||
# so we need to run them through the parser as well.
|
||||
rendered_tokens = super().render_for_completion()
|
||||
|
||||
last_n = -1
|
||||
to_process = []
|
||||
while rendered_tokens[last_n] != self.last_tok:
|
||||
to_process.append(rendered_tokens[last_n])
|
||||
last_n -= 1
|
||||
for tok in reversed(to_process):
|
||||
self.parser.process(tok)
|
||||
|
||||
return rendered_tokens
|
||||
552
vllm/entrypoints/openai/responses/harmony.py
Normal file
@@ -0,0 +1,552 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Harmony ↔ Responses API conversion utilities.
|
||||
|
||||
Handles two directions:
|
||||
1. Response Input → Harmony Messages (input parsing)
|
||||
2. Harmony Messages → Response Output Items (output parsing)
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
ResponseOutputItem,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ActionFind,
|
||||
ActionOpenPage,
|
||||
ActionSearch,
|
||||
ResponseFunctionWebSearch,
|
||||
)
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from openai_harmony import Author, Message, Role, StreamableParser, TextContent
|
||||
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
BUILTIN_TOOL_TO_MCP_SERVER_LABEL,
|
||||
flatten_chat_text_content,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Private helpers for input parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_harmony_format_message(chat_msg: dict) -> Message:
|
||||
"""Reconstruct a Message from Harmony-format dict,
|
||||
preserving channel, recipient, and content_type."""
|
||||
author_dict = chat_msg["author"]
|
||||
role = author_dict.get("role")
|
||||
name = author_dict.get("name")
|
||||
|
||||
raw_content = chat_msg.get("content", "")
|
||||
if isinstance(raw_content, list):
|
||||
# TODO: Support refusal and non-text content types.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in raw_content]
|
||||
elif isinstance(raw_content, str):
|
||||
contents = [TextContent(text=raw_content)]
|
||||
else:
|
||||
contents = [TextContent(text="")]
|
||||
|
||||
if name:
|
||||
msg = Message.from_author_and_contents(Author.new(Role(role), name), contents)
|
||||
else:
|
||||
msg = Message.from_role_and_contents(Role(role), contents)
|
||||
|
||||
channel = chat_msg.get("channel")
|
||||
if channel:
|
||||
msg = msg.with_channel(channel)
|
||||
recipient = chat_msg.get("recipient")
|
||||
if recipient:
|
||||
msg = msg.with_recipient(recipient)
|
||||
content_type = chat_msg.get("content_type")
|
||||
if content_type:
|
||||
msg = msg.with_content_type(content_type)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
|
||||
"""Parse an OpenAI chat-format dict into Harmony messages."""
|
||||
role = chat_msg.get("role")
|
||||
if role is None:
|
||||
raise ValueError(f"Message has no 'role' key: {chat_msg}")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
if role == "assistant" and tool_calls:
|
||||
msgs: list[Message] = []
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
if name and not name.startswith("functions."):
|
||||
name = f"functions.{name}"
|
||||
content = chat_msg.get("content", "") or ""
|
||||
content = flatten_chat_text_content(content)
|
||||
# NOTE: .with_recipient("assistant") is required on tool messages
|
||||
# to match parse_chat_input_to_harmony_message behavior and ensure
|
||||
# proper routing in the Harmony protocol.
|
||||
msg = (
|
||||
Message.from_author_and_content(Author.new(Role.TOOL, name), content)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
return [msg]
|
||||
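# Illustrative sketch (not part of the module): how an assistant tool call in
# OpenAI chat format maps onto a Harmony message via the helper above. The
# tool name and arguments below are hypothetical.
_example_msgs = _parse_chat_format_message(
    {
        "role": "assistant",
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
        ],
    }
)
assert _example_msgs[0].channel == "commentary"
assert _example_msgs[0].recipient == "functions.get_weather"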
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Public input parsing functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def response_input_to_harmony(
|
||||
response_msg: ResponseInputOutputItem,
|
||||
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
|
||||
) -> Message:
|
||||
"""Convert a single ResponseInputOutputItem into a Harmony Message."""
|
||||
if not isinstance(response_msg, dict):
|
||||
response_msg = response_msg.model_dump()
|
||||
if "type" not in response_msg or response_msg["type"] == "message":
|
||||
role = response_msg["role"]
|
||||
content = response_msg["content"]
|
||||
# Add prefix for developer messages.
|
||||
# <|start|>developer<|message|># Instructions {instructions}<|end|>
|
||||
text_prefix = "Instructions:\n" if role == "developer" else ""
|
||||
if isinstance(content, str):
|
||||
msg = Message.from_role_and_content(role, text_prefix + content)
|
||||
else:
|
||||
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: ResponseFunctionToolCall | None = None
|
||||
for prev_response in reversed(prev_responses):
|
||||
if (
|
||||
isinstance(prev_response, ResponseFunctionToolCall)
|
||||
and prev_response.call_id == call_id
|
||||
):
|
||||
call_response = prev_response
|
||||
break
|
||||
if call_response is None:
|
||||
raise ValueError(f"No call message found for {call_id}")
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{call_response.name}"),
|
||||
response_msg["output"],
|
||||
)
|
||||
elif response_msg["type"] == "reasoning":
|
||||
content = response_msg["content"]
|
||||
assert len(content) == 1
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
|
||||
elif response_msg["type"] == "function_call":
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{response_msg['name']}")
|
||||
msg = msg.with_content_type("json")
|
||||
else:
|
||||
raise ValueError(f"Unknown input type: {response_msg['type']}")
|
||||
return msg
|
||||
|
||||
|
||||
def response_previous_input_to_harmony(chat_msg) -> list[Message]:
|
||||
"""Parse a message from request.previous_input_messages
|
||||
into Harmony messages.
|
||||
|
||||
Supports both OpenAI chat format ({"role": "..."}) and
|
||||
Harmony format ({"author": {"role": "..."}}).
|
||||
"""
|
||||
if not isinstance(chat_msg, dict):
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
if "author" in chat_msg and isinstance(chat_msg.get("author"), dict):
|
||||
return [_parse_harmony_format_message(chat_msg)]
|
||||
|
||||
return _parse_chat_format_message(chat_msg)
|
||||
|
||||
|
||||
def construct_harmony_previous_input_messages(
|
||||
request: ResponsesRequest,
|
||||
) -> list[Message]:
|
||||
"""Build a Harmony message list from request.previous_input_messages.
|
||||
|
||||
Filters out system/developer messages to match OpenAI behavior where
|
||||
instructions are always taken from the most recent Responses API request.
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
if request.previous_input_messages:
|
||||
for message in request.previous_input_messages:
|
||||
# Handle both Message objects and dictionary inputs
|
||||
if isinstance(message, Message):
|
||||
message_role = message.author.role
|
||||
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
|
||||
continue
|
||||
messages.append(message)
|
||||
else:
|
||||
harmony_messages = response_previous_input_to_harmony(message)
|
||||
for harmony_msg in harmony_messages:
|
||||
message_role = harmony_msg.author.role
|
||||
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
|
||||
continue
|
||||
messages.append(harmony_msg)
|
||||
return messages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Private helpers for output parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
|
||||
"""Parse browser tool calls (search, open, find) into web search items."""
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
|
||||
# Parse JSON args (with retry detection)
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Invalid JSON in browser tool call, using error placeholder: %s",
|
||||
content.text,
|
||||
)
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
browser_call = {
|
||||
"query": json_retry_output_message,
|
||||
"url": json_retry_output_message,
|
||||
"pattern": json_retry_output_message,
|
||||
}
|
||||
|
||||
# Create appropriate action based on recipient
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search"
|
||||
)
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
|
||||
)
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(
|
||||
pattern=browser_call.get("pattern", ""),
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
|
||||
return ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
|
||||
|
||||
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
|
||||
"""Parse function calls into function tool call items."""
|
||||
function_name = recipient.split(".")[-1]
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_reasoning(message: Message) -> list[ResponseOutputItem]:
|
||||
"""Parse reasoning/analysis content into reasoning items."""
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_final_message(message: Message) -> ResponseOutputItem:
|
||||
"""Parse final channel messages into output message items."""
|
||||
contents = []
|
||||
for content in message.content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content.text,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
contents.append(output_text)
|
||||
return ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=contents,
|
||||
role=message.author.role,
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
|
||||
|
||||
def _parse_mcp_recipient(recipient: str) -> tuple[str, str]:
|
||||
"""Parse MCP recipient into (server_label, tool_name).
|
||||
|
||||
For dotted recipients like "repo_browser.list":
|
||||
- server_label: "repo_browser" (namespace/server)
|
||||
- tool_name: "list" (specific tool)
|
||||
|
||||
For simple recipients like "filesystem":
|
||||
- server_label: "filesystem"
|
||||
- tool_name: "filesystem"
|
||||
"""
|
||||
if "." in recipient:
|
||||
server_label = recipient.split(".")[0]
|
||||
tool_name = recipient.split(".")[-1]
|
||||
else:
|
||||
server_label = recipient
|
||||
tool_name = recipient
|
||||
return server_label, tool_name
|
||||
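# Illustrative sketch (not part of the module): the two recipient shapes the
# docstring above describes. "repo_browser.list" and "filesystem" are examples.
assert _parse_mcp_recipient("repo_browser.list") == ("repo_browser", "list")
assert _parse_mcp_recipient("filesystem") == ("filesystem", "filesystem")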
|
||||
|
||||
def _parse_mcp_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
|
||||
"""Parse MCP calls into MCP call items."""
|
||||
# Handle built-in tools that need server_label mapping
|
||||
if recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
|
||||
server_label = BUILTIN_TOOL_TO_MCP_SERVER_LABEL[recipient]
|
||||
tool_name = recipient
|
||||
else:
|
||||
server_label, tool_name = _parse_mcp_recipient(recipient)
|
||||
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
response_item = McpCall(
|
||||
arguments=content.text,
|
||||
type="mcp_call",
|
||||
name=tool_name,
|
||||
server_label=server_label,
|
||||
id=f"mcp_{random_uuid()}",
|
||||
status="completed",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_message_no_recipient(
|
||||
message: Message,
|
||||
) -> list[ResponseOutputItem]:
|
||||
"""Parse a Harmony message with no recipient based on its channel."""
|
||||
if message.channel == "analysis":
|
||||
return _parse_reasoning(message)
|
||||
|
||||
if message.channel in ("commentary", "final"):
|
||||
# Per Harmony format, preambles (commentary with no recipient) and
|
||||
# final channel content are both intended to be shown to end-users.
|
||||
# See: https://cookbook.openai.com/articles/openai-harmony
|
||||
return [_parse_final_message(message)]
|
||||
|
||||
raise ValueError(f"Unknown channel: {message.channel}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Public output parsing functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def harmony_to_response_output(message: Message) -> list[ResponseOutputItem]:
|
||||
"""Parse a Harmony message into a list of output response items.
|
||||
|
||||
This is the main dispatcher that routes based on channel and recipient.
|
||||
"""
|
||||
if message.author.role != "assistant":
|
||||
# This is a message from a tool to the assistant (e.g., search result).
|
||||
# Don't include it in the final output for now. This aligns with
|
||||
# OpenAI's behavior on models like o4-mini.
|
||||
return []
|
||||
|
||||
output_items: list[ResponseOutputItem] = []
|
||||
recipient = message.recipient
|
||||
|
||||
if recipient is not None:
|
||||
# Browser tool calls (browser.search, browser.open, browser.find)
|
||||
if recipient.startswith("browser."):
|
||||
output_items.append(_parse_browser_tool_call(message, recipient))
|
||||
|
||||
# Function calls (should only happen on commentary channel)
|
||||
elif message.channel == "commentary" and recipient.startswith("functions."):
|
||||
output_items.extend(_parse_function_call(message, recipient))
|
||||
|
||||
# Built-in MCP tools (python, browser, container)
|
||||
elif recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
|
||||
output_items.extend(_parse_reasoning(message))
|
||||
|
||||
# All other recipients are MCP calls
|
||||
else:
|
||||
output_items.extend(_parse_mcp_call(message, recipient))
|
||||
|
||||
# No recipient - handle based on channel for non-tool messages
|
||||
else:
|
||||
output_items.extend(_parse_message_no_recipient(message))
|
||||
|
||||
return output_items
|
||||
|
||||
|
||||
def parser_state_to_response_output(
|
||||
parser: StreamableParser,
|
||||
) -> list[ResponseOutputItem]:
|
||||
"""Extract in-progress response items from incomplete parser state.
|
||||
|
||||
Called when the parser has buffered content that hasn't formed a
|
||||
complete message yet (e.g., generation was cut short).
|
||||
"""
|
||||
if not parser.current_content:
|
||||
return []
|
||||
if parser.current_role != Role.ASSISTANT:
|
||||
return []
|
||||
current_recipient = parser.current_recipient
|
||||
if current_recipient is not None and current_recipient.startswith("browser."):
|
||||
return []
|
||||
|
||||
if current_recipient and parser.current_channel in ("commentary", "analysis"):
|
||||
if current_recipient.startswith("functions."):
|
||||
rid = random_uuid()
|
||||
return [
|
||||
ResponseFunctionToolCall(
|
||||
arguments=parser.current_content,
|
||||
call_id=f"call_{rid}",
|
||||
type="function_call",
|
||||
name=current_recipient.split(".")[-1],
|
||||
id=f"fc_{rid}",
|
||||
status="in_progress",
|
||||
)
|
||||
]
|
||||
# Built-in MCP tools (python, browser, container)
|
||||
elif current_recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
|
||||
return [
|
||||
ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=parser.current_content, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
]
|
||||
# All other recipients are MCP calls
|
||||
else:
|
||||
rid = random_uuid()
|
||||
server_label, tool_name = _parse_mcp_recipient(current_recipient)
|
||||
return [
|
||||
McpCall(
|
||||
arguments=parser.current_content,
|
||||
type="mcp_call",
|
||||
name=tool_name,
|
||||
server_label=server_label,
|
||||
id=f"mcp_{rid}",
|
||||
status="in_progress",
|
||||
)
|
||||
]
|
||||
|
||||
if parser.current_channel == "commentary":
|
||||
# Per Harmony format, preambles (commentary with no recipient) are
|
||||
# intended to be shown to end-users, unlike analysis channel content.
|
||||
output_text = ResponseOutputText(
|
||||
text=parser.current_content,
|
||||
annotations=[],
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
return [
|
||||
ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=[output_text],
|
||||
role="assistant",
|
||||
status="incomplete",
|
||||
type="message",
|
||||
)
|
||||
]
|
||||
|
||||
if parser.current_channel == "analysis":
|
||||
return [
|
||||
ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=parser.current_content, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
]
|
||||
|
||||
if parser.current_channel == "final":
|
||||
output_text = ResponseOutputText(
|
||||
text=parser.current_content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=[output_text],
|
||||
role="assistant",
|
||||
# if the parser still has messages (i.e. the generation was cut
|
||||
# abruptly), this should be incomplete
|
||||
status="incomplete",
|
||||
type="message",
|
||||
)
|
||||
return [text_item]
|
||||
|
||||
return []
|
||||
641
vllm/entrypoints/openai/responses/protocol.py
Normal file
@@ -0,0 +1,641 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from openai.types.responses import (
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInputItemParam,
|
||||
ResponseMcpCallArgumentsDeltaEvent,
|
||||
ResponseMcpCallArgumentsDoneEvent,
|
||||
ResponseMcpCallCompletedEvent,
|
||||
ResponseMcpCallInProgressEvent,
|
||||
ResponseOutputItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponsePrompt,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseStatus,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
)
|
||||
from openai.types.responses import (
|
||||
ResponseCompletedEvent as OpenAIResponseCompletedEvent,
|
||||
)
|
||||
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
|
||||
from openai.types.responses import (
|
||||
ResponseInProgressEvent as OpenAIResponseInProgressEvent,
|
||||
)
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
|
||||
# Backward compatibility for OpenAI client versions
|
||||
try: # For older openai versions (< 1.100.0)
|
||||
from openai.types.responses import ResponseTextConfig
|
||||
except ImportError: # For newer openai versions (>= 1.100.0)
|
||||
from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
|
||||
|
||||
from openai.types.responses.response import IncompleteDetails, ToolChoice
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from openai.types.shared import Metadata, Reasoning
|
||||
from pydantic import (
|
||||
Field,
|
||||
ValidationError,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import (
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class InputTokensDetails(OpenAIBaseModel):
|
||||
cached_tokens: int
|
||||
input_tokens_per_turn: list[int] = Field(default_factory=list)
|
||||
cached_tokens_per_turn: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OutputTokensDetails(OpenAIBaseModel):
|
||||
reasoning_tokens: int = 0
|
||||
tool_output_tokens: int = 0
|
||||
output_tokens_per_turn: list[int] = Field(default_factory=list)
|
||||
tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseUsage(OpenAIBaseModel):
|
||||
input_tokens: int
|
||||
input_tokens_details: InputTokensDetails
|
||||
output_tokens: int
|
||||
output_tokens_details: OutputTokensDetails
|
||||
total_tokens: int
|
||||
|
||||
|
||||
def serialize_message(msg):
|
||||
"""
|
||||
Serializes a single message
|
||||
"""
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
elif hasattr(msg, "to_dict"):
|
||||
return msg.to_dict()
|
||||
else:
|
||||
# fall back to pydantic JSON dump
|
||||
return msg.model_dump_json()
|
||||
|
||||
|
||||
def serialize_messages(msgs):
|
||||
"""
|
||||
Serializes multiple messages
|
||||
"""
|
||||
return [serialize_message(msg) for msg in msgs] if msgs else None
|
||||
|
||||
|
||||
class ResponseRawMessageAndToken(OpenAIBaseModel):
|
||||
"""Class to show the raw message.
|
||||
If the message and tokens diverge, the tokens are the source of truth."""
|
||||
|
||||
message: str
|
||||
tokens: list[int]
|
||||
type: Literal["raw_message_tokens"] = "raw_message_tokens"
|
||||
|
||||
|
||||
ResponseInputOutputMessage: TypeAlias = (
|
||||
list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
|
||||
)
|
||||
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
|
||||
|
||||
|
||||
class ResponsesRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/responses/create
|
||||
background: bool | None = False
|
||||
include: (
|
||||
list[
|
||||
Literal[
|
||||
"code_interpreter_call.outputs",
|
||||
"computer_call_output.output.image_url",
|
||||
"file_search_call.results",
|
||||
"message.input_image.image_url",
|
||||
"message.output_text.logprobs",
|
||||
"reasoning.encrypted_content",
|
||||
],
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
input: str | list[ResponseInputOutputItem]
|
||||
instructions: str | None = None
|
||||
max_output_tokens: int | None = None
|
||||
max_tool_calls: int | None = None
|
||||
metadata: Metadata | None = None
|
||||
model: str | None = None
|
||||
logit_bias: dict[str, float] | None = None
|
||||
parallel_tool_calls: bool | None = True
|
||||
previous_response_id: str | None = None
|
||||
prompt: ResponsePrompt | None = None
|
||||
reasoning: Reasoning | None = None
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
|
||||
store: bool | None = True
|
||||
stream: bool | None = False
|
||||
temperature: float | None = None
|
||||
text: ResponseTextConfig | None = None
|
||||
tool_choice: ToolChoice = "auto"
|
||||
tools: list[Tool] = Field(default_factory=list)
|
||||
top_logprobs: int | None = 0
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
truncation: Literal["auto", "disabled"] | None = "disabled"
|
||||
user: str | None = None
|
||||
skip_special_tokens: bool = True
|
||||
include_stop_str_in_output: bool = False
|
||||
prompt_cache_key: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A key that was used to read from or write to the prompt cache."
|
||||
"Note: This field has not been implemented yet "
|
||||
"and vLLM will ignore it."
|
||||
),
|
||||
)
|
||||
|
||||
# --8<-- [start:responses-extra-params]
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"resp_{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
|
||||
enable_response_messages: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Dictates whether or not to return messages as part of the "
|
||||
"response object. Currently only supported for non-background."
|
||||
),
|
||||
)
|
||||
# similar to input_messages / output_messages in ResponsesResponse
|
||||
# we take in previous_input_messages (i.e., in Harmony format)
|
||||
# this cannot be used in conjunction with previous_response_id
|
||||
# TODO: consider supporting non-Harmony messages as well
|
||||
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
|
||||
structured_outputs: StructuredOutputsParams | None = Field(
|
||||
default=None,
|
||||
description="Additional kwargs for structured outputs",
|
||||
)
|
||||
|
||||
repetition_penalty: float | None = None
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: str | list[str] | None = []
|
||||
ignore_eos: bool = False
|
||||
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with (list of) string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
from .utils import should_continue_final_message
|
||||
|
||||
# Check if we should continue the final message (partial completion)
|
||||
# This enables Anthropic-style partial message completion where the
|
||||
# user provides an incomplete assistant message to continue from.
|
||||
continue_final = should_continue_final_message(self.input)
|
||||
|
||||
reasoning = self.reasoning
|
||||
|
||||
return ChatParams(
|
||||
chat_template=default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs( # To remove unset values
|
||||
{},
|
||||
dict(
|
||||
add_generation_prompt=not continue_final,
|
||||
continue_final_message=continue_final,
|
||||
reasoning_effort=None if reasoning is None else reasoning.effort,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=self.max_output_tokens or 0,
|
||||
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_output_tokens",
|
||||
)
|
||||
|
||||
_DEFAULT_SAMPLING_PARAMS = {
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: dict | None = None,
|
||||
) -> SamplingParams:
|
||||
if self.max_output_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
else:
|
||||
max_tokens = min(self.max_output_tokens, default_max_tokens)
|
||||
|
||||
default_sampling_params = default_sampling_params or {}
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
||||
)
|
||||
if (top_k := self.top_k) is None:
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||
)
|
||||
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)
|
||||
|
||||
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
||||
|
||||
# Structured output
|
||||
structured_outputs = self.structured_outputs
|
||||
|
||||
# Also check text.format for OpenAI-style json_schema
|
||||
if self.text is not None and self.text.format is not None:
|
||||
if structured_outputs is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both structured_outputs and text.format"
|
||||
)
|
||||
response_format = self.text.format
|
||||
if (
|
||||
response_format.type == "json_schema"
|
||||
and response_format.schema_ is not None
|
||||
):
|
||||
structured_outputs = StructuredOutputsParams(
|
||||
json=response_format.schema_ # type: ignore[call-arg]
|
||||
# --follow-imports skip hides the class definition but also hides
|
||||
# multiple third-party conflicts, so this is the lesser of two evils
|
||||
)
|
||||
|
||||
stop = self.stop if self.stop else []
|
||||
if isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop=stop,
|
||||
repetition_penalty=repetition_penalty,
|
||||
seed=self.seed,
|
||||
ignore_eos=self.ignore_eos,
|
||||
output_kind=(
|
||||
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
|
||||
),
|
||||
structured_outputs=structured_outputs,
|
||||
logit_bias=self.logit_bias,
|
||||
extra_args=self.vllm_xargs or {},
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def is_include_output_logprobs(self) -> bool:
|
||||
"""Check if the request includes output logprobs."""
|
||||
if self.include is None:
|
||||
return False
|
||||
return (
|
||||
isinstance(self.include, list)
|
||||
and "message.output_text.logprobs" in self.include
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_background(cls, data):
|
||||
if not data.get("background"):
|
||||
return data
|
||||
if not data.get("store", True):
|
||||
raise ValueError("background can only be used when `store` is true")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_prompt(cls, data):
|
||||
if data.get("prompt") is not None:
|
||||
raise VLLMValidationError(
|
||||
"prompt template is not supported", parameter="prompt"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_cache_salt_support(cls, data):
|
||||
if data.get("cache_salt") is not None and (
|
||||
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
def function_call_parsing(cls, data):
|
||||
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
|
||||
This ensures Pydantic can properly resolve union types in the input field.
|
||||
Function calls provided as dicts are converted to ResponseFunctionToolCall
|
||||
objects before validation, while invalid structures are left for Pydantic
|
||||
to reject with appropriate error messages.
|
||||
"""
|
||||
|
||||
input_data = data.get("input")
|
||||
|
||||
# Early return for None, strings, or bytes
|
||||
# (strings are iterable but shouldn't be processed)
|
||||
if input_data is None or isinstance(input_data, (str, bytes)):
|
||||
return data
|
||||
|
||||
# Convert iterators (like ValidatorIterator) to list
|
||||
if not isinstance(input_data, list):
|
||||
try:
|
||||
input_data = list(input_data)
|
||||
except TypeError:
|
||||
# Not iterable, leave as-is for Pydantic to handle
|
||||
return data
|
||||
|
||||
processed_input = []
|
||||
for item in input_data:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
try:
|
||||
processed_input.append(ResponseFunctionToolCall(**item))
|
||||
except ValidationError:
|
||||
# Let Pydantic handle validation for malformed function calls
|
||||
logger.debug(
|
||||
"Failed to parse function_call to ResponseFunctionToolCall, "
|
||||
"leaving for Pydantic validation"
|
||||
)
|
||||
processed_input.append(item)
|
||||
else:
|
||||
processed_input.append(item)
|
||||
|
||||
data["input"] = processed_input
|
||||
return data
|
||||
|
||||
|
||||
class ResponsesResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
# error: Optional[ResponseError] = None
|
||||
incomplete_details: IncompleteDetails | None = None
|
||||
instructions: str | None = None
|
||||
metadata: Metadata | None = None
|
||||
model: str
|
||||
object: Literal["response"] = "response"
|
||||
output: list[ResponseOutputItem]
|
||||
parallel_tool_calls: bool
|
||||
temperature: float
|
||||
tool_choice: ToolChoice
|
||||
tools: list[Tool]
|
||||
top_p: float
|
||||
background: bool
|
||||
max_output_tokens: int
|
||||
max_tool_calls: int | None = None
|
||||
previous_response_id: str | None = None
|
||||
prompt: ResponsePrompt | None = None
|
||||
reasoning: Reasoning | None = None
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
|
||||
status: ResponseStatus
|
||||
text: ResponseTextConfig | None = None
|
||||
top_logprobs: int | None = None
|
||||
truncation: Literal["auto", "disabled"]
|
||||
usage: ResponseUsage | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:responses-response-extra-params]
|
||||
# These are populated when enable_response_messages is set to True
|
||||
# NOTE: custom serialization is needed
|
||||
# see serialize_input_messages and serialize_output_messages
|
||||
input_messages: ResponseInputOutputMessage | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If enable_response_messages, we can show raw token input to model."
|
||||
),
|
||||
)
|
||||
output_messages: ResponseInputOutputMessage | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If enable_response_messages, we can show raw token output of model."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:responses-response-extra-params]
|
||||
|
||||
# NOTE: OpenAI Harmony doesn't serialize TextContent properly,
|
||||
# TODO: this fixes TextContent, but still needs to be verified for tools etc.
|
||||
# https://github.com/openai/harmony/issues/78
|
||||
@field_serializer("output_messages", when_used="json")
|
||||
def serialize_output_messages(self, msgs, _info):
|
||||
return serialize_messages(msgs)
|
||||
|
||||
# NOTE: OpenAI Harmony doesn't serialize TextContent properly; this fixes it
|
||||
# https://github.com/openai/harmony/issues/78
|
||||
@field_serializer("input_messages", when_used="json")
|
||||
def serialize_input_messages(self, msgs, _info):
|
||||
return serialize_messages(msgs)
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: ResponsesRequest,
|
||||
sampling_params: SamplingParams,
|
||||
model_name: str,
|
||||
created_time: int,
|
||||
output: list[ResponseOutputItem],
|
||||
status: ResponseStatus,
|
||||
usage: ResponseUsage | None = None,
|
||||
input_messages: ResponseInputOutputMessage | None = None,
|
||||
output_messages: ResponseInputOutputMessage | None = None,
|
||||
) -> "ResponsesResponse":
|
||||
incomplete_details: IncompleteDetails | None = None
|
||||
if status == "incomplete":
|
||||
incomplete_details = IncompleteDetails(reason="max_output_tokens")
|
||||
# TODO: implement the other reason for incomplete_details,
|
||||
# which is content_filter
|
||||
# incomplete_details = IncompleteDetails(reason='content_filter')
|
||||
return cls(
|
||||
id=request.request_id,
|
||||
created_at=created_time,
|
||||
incomplete_details=incomplete_details,
|
||||
instructions=request.instructions,
|
||||
metadata=request.metadata,
|
||||
model=model_name,
|
||||
output=output,
|
||||
input_messages=input_messages,
|
||||
output_messages=output_messages,
|
||||
parallel_tool_calls=request.parallel_tool_calls,
|
||||
temperature=sampling_params.temperature,
|
||||
tool_choice=request.tool_choice,
|
||||
tools=request.tools,
|
||||
top_p=sampling_params.top_p,
|
||||
background=request.background,
|
||||
max_output_tokens=sampling_params.max_tokens,
|
||||
max_tool_calls=request.max_tool_calls,
|
||||
previous_response_id=request.previous_response_id,
|
||||
prompt=request.prompt,
|
||||
reasoning=request.reasoning,
|
||||
service_tier=request.service_tier,
|
||||
status=status,
|
||||
text=request.text,
|
||||
top_logprobs=sampling_params.logprobs,
|
||||
truncation=request.truncation,
|
||||
user=request.user,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
# TODO: this code can be removed once
|
||||
# https://github.com/openai/openai-python/issues/2634 has been resolved
|
||||
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
|
||||
content_index: int
|
||||
"""The index of the content part that is done."""
|
||||
|
||||
item_id: str
|
||||
"""The ID of the output item that the content part was added to."""
|
||||
|
||||
output_index: int
|
||||
"""The index of the output item that the content part was added to."""
|
||||
|
||||
part: ResponseReasoningTextContent
|
||||
"""The content part that is done."""
|
||||
|
||||
sequence_number: int
|
||||
"""The sequence number of this event."""
|
||||
|
||||
type: Literal["response.reasoning_part.done"]
|
||||
"""The type of the event. Always `response.reasoning_part.done`."""
|
||||
|
||||
|
||||
# TODO: this code can be removed once
|
||||
# https://github.com/openai/openai-python/issues/2634 has been resolved
|
||||
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
|
||||
content_index: int
|
||||
"""The index of the content part that is done."""
|
||||
|
||||
item_id: str
|
||||
"""The ID of the output item that the content part was added to."""
|
||||
|
||||
output_index: int
|
||||
"""The index of the output item that the content part was added to."""
|
||||
|
||||
part: ResponseReasoningTextContent
|
||||
"""The content part that is done."""
|
||||
|
||||
sequence_number: int
|
||||
"""The sequence number of this event."""
|
||||
|
||||
type: Literal["response.reasoning_part.added"]
|
||||
"""The type of the event. Always `response.reasoning_part.added`."""
|
||||
|
||||
|
||||
# vLLM Streaming Events
|
||||
# Note: we override the response type with the vLLM ResponsesResponse type
|
||||
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
|
||||
response: ResponsesResponse # type: ignore[override]
|
||||
|
||||
|
||||
class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
|
||||
response: ResponsesResponse # type: ignore[override]
|
||||
|
||||
|
||||
class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
|
||||
response: ResponsesResponse # type: ignore[override]
|
||||
|
||||
|
||||
StreamingResponsesResponse: TypeAlias = (
|
||||
ResponseCreatedEvent
|
||||
| ResponseInProgressEvent
|
||||
| ResponseCompletedEvent
|
||||
| ResponseOutputItemAddedEvent
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseContentPartAddedEvent
|
||||
| ResponseContentPartDoneEvent
|
||||
| ResponseReasoningTextDeltaEvent
|
||||
| ResponseReasoningTextDoneEvent
|
||||
| ResponseReasoningPartAddedEvent
|
||||
| ResponseReasoningPartDoneEvent
|
||||
| ResponseCodeInterpreterCallInProgressEvent
|
||||
| ResponseCodeInterpreterCallCodeDeltaEvent
|
||||
| ResponseWebSearchCallInProgressEvent
|
||||
| ResponseWebSearchCallSearchingEvent
|
||||
| ResponseWebSearchCallCompletedEvent
|
||||
| ResponseCodeInterpreterCallCodeDoneEvent
|
||||
| ResponseCodeInterpreterCallInterpretingEvent
|
||||
| ResponseCodeInterpreterCallCompletedEvent
|
||||
| ResponseMcpCallArgumentsDeltaEvent
|
||||
| ResponseMcpCallArgumentsDoneEvent
|
||||
| ResponseMcpCallInProgressEvent
|
||||
| ResponseMcpCallCompletedEvent
|
||||
)
|
||||
1723
vllm/entrypoints/openai/responses/serving.py
Normal file
File diff suppressed because it is too large
798
vllm/entrypoints/openai/responses/streaming_events.py
Normal file
@@ -0,0 +1,798 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Streaming SSE event builders for the Responses API.
|
||||
|
||||
Pure functions that translate streaming state + delta data into
|
||||
OpenAI Response API SSE events. Used by the streaming event
|
||||
processors in serving.py.
|
||||
|
||||
The file is organized as:
|
||||
1. StreamingState dataclass + utility helpers
|
||||
2. Shared leaf helpers — delta events (take plain strings, no context)
|
||||
3. Shared leaf helpers — done events (take plain strings, no context)
|
||||
4. Harmony-specific dispatchers (route ctx/previous_item → leaf helpers)
|
||||
5. Harmony-specific tool lifecycle helpers
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
ResponseCodeInterpreterToolCallParam,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionWebSearch,
|
||||
ResponseMcpCallArgumentsDeltaEvent,
|
||||
ResponseMcpCallArgumentsDoneEvent,
|
||||
ResponseMcpCallCompletedEvent,
|
||||
ResponseMcpCallInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
response_function_web_search,
|
||||
)
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from openai_harmony import Message as HarmonyMessage
|
||||
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
from vllm.entrypoints.openai.responses.context import StreamingHarmonyContext
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
StreamingResponsesResponse,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
    "python": "code_interpreter",
    "container": "container",
    "browser": "web_search_preview",
}


def _resolve_mcp_name_label(recipient: str) -> tuple[str, str]:
    """Resolve MCP tool name and server label from a recipient string.

    - ``mcp.*`` recipients: strip prefix, use the bare name as both
      name and server_label.
    - Everything else: use the recipient as the name and look up the
      server_label in TOOL_NAME_TO_MCP_SERVER_LABEL.
    """
    if recipient.startswith("mcp."):
        name = recipient[len("mcp.") :]
        return name, name
    return recipient, TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
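# Illustrative behavior of _resolve_mcp_name_label (worked examples only,
# not code from this commit):
#
#     _resolve_mcp_name_label("mcp.github")  ->  ("github", "github")
#     _resolve_mcp_name_label("python")      ->  ("python", "code_interpreter")
#     _resolve_mcp_name_label("unknown")     ->  ("unknown", "unknown")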
@dataclass
class StreamingState:
    """Mutable state for streaming event processing."""

    current_content_index: int = -1
    current_output_index: int = 0
    current_item_id: str = ""
    current_call_id: str = ""
    sent_output_item_added: bool = False
    is_first_function_call_delta: bool = False

    def reset_for_new_item(self) -> None:
        """Reset state when expecting a new output item."""
        self.current_output_index += 1
        self.sent_output_item_added = False
        self.is_first_function_call_delta = False
        self.current_call_id = ""


def is_mcp_tool_by_namespace(recipient: str | None) -> bool:
    """
    Determine if a tool call is an MCP tool based on recipient prefix.

    - Tools starting with "functions." are function calls
    - Everything else is an MCP tool
    """
    if recipient is None:
        return False

    # Function calls have "functions." prefix
    # Everything else is an MCP tool
    return not recipient.startswith("functions.")
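# Minimal usage sketch (assumed call order, not code from this commit): a
# streaming loop keeps one StreamingState per response and pushes each text
# delta through the leaf helpers below, e.g.
#
#     state = StreamingState()
#     for delta in ("Hel", "lo"):
#         for event in emit_text_delta_events(delta, state):
#             ...  # serialize as an SSE line and send to the client
#
# The first delta also yields `response.output_item.added` and
# `response.content_part.added`; every delta yields a
# `response.output_text.delta` event.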
# =====================================================================
|
||||
# Shared leaf helpers — delta events
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def emit_text_delta_events(
|
||||
delta: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for text content delta streaming."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
if not state.sent_output_item_added:
|
||||
state.sent_output_item_added = True
|
||||
state.current_item_id = f"msg_{random_uuid()}"
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseOutputMessage(
|
||||
id=state.current_item_id,
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[],
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
state.current_content_index += 1
|
||||
events.append(
|
||||
ResponseContentPartAddedEvent(
|
||||
type="response.content_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
content_index=state.current_content_index,
|
||||
part=ResponseOutputText(
|
||||
type="output_text",
|
||||
text="",
|
||||
annotations=[],
|
||||
logprobs=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseTextDeltaEvent(
|
||||
type="response.output_text.delta",
|
||||
sequence_number=-1,
|
||||
content_index=state.current_content_index,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
delta=delta,
|
||||
# TODO, use logprobs from ctx.last_request_output
|
||||
logprobs=[],
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_reasoning_delta_events(
|
||||
delta: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for reasoning text delta streaming."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
if not state.sent_output_item_added:
|
||||
state.sent_output_item_added = True
|
||||
state.current_item_id = f"msg_{random_uuid()}"
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
id=state.current_item_id,
|
||||
summary=[],
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
state.current_content_index += 1
|
||||
events.append(
|
||||
ResponseReasoningPartAddedEvent(
|
||||
type="response.reasoning_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
content_index=state.current_content_index,
|
||||
part=ResponseReasoningTextContent(
|
||||
text="",
|
||||
type="reasoning_text",
|
||||
),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseReasoningTextDeltaEvent(
|
||||
type="response.reasoning_text.delta",
|
||||
item_id=state.current_item_id,
|
||||
output_index=state.current_output_index,
|
||||
content_index=state.current_content_index,
|
||||
delta=delta,
|
||||
sequence_number=-1,
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_function_call_delta_events(
|
||||
delta: str,
|
||||
function_name: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for function call argument deltas."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
if state.is_first_function_call_delta is False:
|
||||
state.is_first_function_call_delta = True
|
||||
state.current_item_id = f"fc_{random_uuid()}"
|
||||
state.current_call_id = f"call_{random_uuid()}"
|
||||
tool_call_item = ResponseFunctionToolCall(
|
||||
name=function_name,
|
||||
type="function_call",
|
||||
id=state.current_item_id,
|
||||
call_id=state.current_call_id,
|
||||
arguments="",
|
||||
status="in_progress",
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=tool_call_item,
|
||||
)
|
||||
)
|
||||
# Always emit the delta (including on first call)
|
||||
events.append(
|
||||
ResponseFunctionCallArgumentsDeltaEvent(
|
||||
item_id=state.current_item_id,
|
||||
delta=delta,
|
||||
output_index=state.current_output_index,
|
||||
sequence_number=-1,
|
||||
type="response.function_call_arguments.delta",
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_mcp_delta_events(
|
||||
delta: str,
|
||||
state: StreamingState,
|
||||
recipient: str,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for MCP tool delta streaming."""
|
||||
name, server_label = _resolve_mcp_name_label(recipient)
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
if not state.sent_output_item_added:
|
||||
state.sent_output_item_added = True
|
||||
state.current_item_id = f"mcp_{random_uuid()}"
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=McpCall(
|
||||
type="mcp_call",
|
||||
id=state.current_item_id,
|
||||
name=name,
|
||||
arguments="",
|
||||
server_label=server_label,
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseMcpCallInProgressEvent(
|
||||
type="response.mcp_call.in_progress",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseMcpCallArgumentsDeltaEvent(
|
||||
type="response.mcp_call_arguments.delta",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_code_interpreter_delta_events(
|
||||
delta: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for code interpreter delta streaming."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
if not state.sent_output_item_added:
|
||||
state.sent_output_item_added = True
|
||||
state.current_item_id = f"tool_{random_uuid()}"
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseCodeInterpreterToolCallParam(
|
||||
type="code_interpreter_call",
|
||||
id=state.current_item_id,
|
||||
code=None,
|
||||
container_id="auto",
|
||||
outputs=None,
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseCodeInterpreterCallInProgressEvent(
|
||||
type="response.code_interpreter_call.in_progress",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent(
|
||||
type="response.code_interpreter_call_code.delta",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Shared leaf helpers — done events
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def emit_text_output_done_events(
|
||||
text: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events when a final text output item completes."""
|
||||
text_content = ResponseOutputText(
|
||||
type="output_text",
|
||||
text=text,
|
||||
annotations=[],
|
||||
)
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseTextDoneEvent(
|
||||
type="response.output_text.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
content_index=state.current_content_index,
|
||||
text=text,
|
||||
logprobs=[],
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseContentPartDoneEvent(
|
||||
type="response.content_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=state.current_item_id,
|
||||
output_index=state.current_output_index,
|
||||
content_index=state.current_content_index,
|
||||
part=text_content,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseOutputMessage(
|
||||
id=state.current_item_id,
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[text_content],
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_reasoning_done_events(
|
||||
text: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events when a reasoning (analysis) item completes."""
|
||||
content = ResponseReasoningTextContent(
|
||||
text=text,
|
||||
type="reasoning_text",
|
||||
)
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
content=[content],
|
||||
status="completed",
|
||||
id=state.current_item_id,
|
||||
summary=[],
|
||||
)
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseReasoningTextDoneEvent(
|
||||
type="response.reasoning_text.done",
|
||||
item_id=state.current_item_id,
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
content_index=state.current_content_index,
|
||||
text=text,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseReasoningPartDoneEvent(
|
||||
type="response.reasoning_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=state.current_item_id,
|
||||
output_index=state.current_output_index,
|
||||
content_index=state.current_content_index,
|
||||
part=content,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=reasoning_item,
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_function_call_done_events(
|
||||
function_name: str,
|
||||
arguments: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events when a function call completes."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseFunctionCallArgumentsDoneEvent(
|
||||
type="response.function_call_arguments.done",
|
||||
arguments=arguments,
|
||||
name=function_name,
|
||||
item_id=state.current_item_id,
|
||||
output_index=state.current_output_index,
|
||||
sequence_number=-1,
|
||||
)
|
||||
)
|
||||
    function_call_item = ResponseFunctionToolCall(
        type="function_call",
        arguments=arguments,
        name=function_name,
        id=state.current_item_id,
        call_id=state.current_call_id,
        status="completed",
    )
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=function_call_item,
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_mcp_completion_events(
|
||||
recipient: str,
|
||||
arguments: str,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events when an MCP tool call completes."""
|
||||
name, server_label = _resolve_mcp_name_label(recipient)
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseMcpCallArgumentsDoneEvent(
|
||||
type="response.mcp_call_arguments.done",
|
||||
arguments=arguments,
|
||||
name=name,
|
||||
item_id=state.current_item_id,
|
||||
output_index=state.current_output_index,
|
||||
sequence_number=-1,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseMcpCallCompletedEvent(
|
||||
type="response.mcp_call.completed",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=McpCall(
|
||||
type="mcp_call",
|
||||
arguments=arguments,
|
||||
name=name,
|
||||
id=state.current_item_id,
|
||||
server_label=server_label,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Harmony-specific dispatchers
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def emit_content_delta_events(
    ctx: StreamingHarmonyContext,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for content delta streaming based on channel type.

    This is a Harmony-specific dispatcher that extracts values from the
    Harmony context and delegates to shared leaf helpers.
    """
    delta = ctx.last_content_delta
    if not delta:
        return []

    channel = ctx.parser.current_channel
    recipient = ctx.parser.current_recipient

    if channel in ("final", "commentary") and recipient is None:
        # Preambles (commentary with no recipient) and final messages
        # are both user-visible text.
        return emit_text_delta_events(delta, state)
    elif channel == "analysis" and recipient is None:
        return emit_reasoning_delta_events(delta, state)
    # Built-in tools are normally triggered on the analysis channel, but
    # occasionally they are still emitted on the commentary channel.
    elif channel in ("commentary", "analysis") and recipient is not None:
        if recipient.startswith("functions."):
            function_name = recipient[len("functions.") :]
            return emit_function_call_delta_events(delta, function_name, state)
        elif recipient == "python":
            return emit_code_interpreter_delta_events(delta, state)
        elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
            return emit_mcp_delta_events(delta, state, recipient)

    return []

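# For reference, an illustrative routing of (channel, recipient) pairs to the
# helpers above (derived from the branches, with a made-up function name):
#
#     ("final", None)                    -> emit_text_delta_events
#     ("commentary", None)               -> emit_text_delta_events (preamble)
#     ("analysis", None)                 -> emit_reasoning_delta_events
#     ("commentary", "functions.lookup") -> emit_function_call_delta_events
#     ("analysis", "python")             -> emit_code_interpreter_delta_events
#     ("analysis", "mcp.github")         -> emit_mcp_delta_events
#     anything else                      -> no events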
def emit_previous_item_done_events(
|
||||
previous_item: HarmonyMessage,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit done events for the previous item when expecting a new start.
|
||||
|
||||
This is a Harmony-specific dispatcher that extracts values from the
|
||||
Harmony parser's message object and delegates to shared leaf helpers.
|
||||
"""
|
||||
text = previous_item.content[0].text
|
||||
if previous_item.recipient is not None:
|
||||
# Deal with tool call
|
||||
if previous_item.recipient.startswith("functions."):
|
||||
function_name = previous_item.recipient[len("functions.") :]
|
||||
return emit_function_call_done_events(function_name, text, state)
|
||||
elif previous_item.recipient == "python":
|
||||
return emit_code_interpreter_completion_events(previous_item, state)
|
||||
elif (
|
||||
is_mcp_tool_by_namespace(previous_item.recipient)
|
||||
and state.current_item_id is not None
|
||||
and state.current_item_id.startswith("mcp_")
|
||||
):
|
||||
return emit_mcp_completion_events(previous_item.recipient, text, state)
|
||||
elif previous_item.channel == "analysis":
|
||||
return emit_reasoning_done_events(text, state)
|
||||
elif previous_item.channel in ("commentary", "final"):
|
||||
# Preambles (commentary with no recipient) and final messages
|
||||
# are both user-visible text.
|
||||
return emit_text_output_done_events(text, state)
|
||||
return []
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Harmony-specific tool lifecycle helpers
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def emit_browser_tool_events(
|
||||
previous_item: HarmonyMessage,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for browser tool calls (web search)."""
|
||||
function_name = previous_item.recipient[len("browser.") :]
|
||||
parsed_args = json.loads(previous_item.content[0].text)
|
||||
action = None
|
||||
|
||||
if function_name == "search":
|
||||
action = response_function_web_search.ActionSearch(
|
||||
type="search",
|
||||
query=parsed_args["query"],
|
||||
)
|
||||
elif function_name == "open":
|
||||
action = response_function_web_search.ActionOpenPage(
|
||||
type="open_page",
|
||||
# TODO: translate to url
|
||||
url=f"cursor:{parsed_args.get('cursor', '')}",
|
||||
)
|
||||
elif function_name == "find":
|
||||
action = response_function_web_search.ActionFind(
|
||||
type="find",
|
||||
pattern=parsed_args["pattern"],
|
||||
# TODO: translate to url
|
||||
url=f"cursor:{parsed_args.get('cursor', '')}",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown function name: {function_name}")
|
||||
|
||||
state.current_item_id = f"tool_{random_uuid()}"
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=response_function_web_search.ResponseFunctionWebSearch(
|
||||
# TODO: generate a unique id for web search call
|
||||
type="web_search_call",
|
||||
id=state.current_item_id,
|
||||
action=action,
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseWebSearchCallInProgressEvent(
|
||||
type="response.web_search_call.in_progress",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseWebSearchCallSearchingEvent(
|
||||
type="response.web_search_call.searching",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
# enqueue
|
||||
events.append(
|
||||
ResponseWebSearchCallCompletedEvent(
|
||||
type="response.web_search_call.completed",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseFunctionWebSearch(
|
||||
type="web_search_call",
|
||||
id=state.current_item_id,
|
||||
action=action,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_code_interpreter_completion_events(
|
||||
previous_item: HarmonyMessage,
|
||||
state: StreamingState,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events when code interpreter completes."""
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
events.append(
|
||||
ResponseCodeInterpreterCallCodeDoneEvent(
|
||||
type="response.code_interpreter_call_code.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
code=previous_item.content[0].text,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseCodeInterpreterCallInterpretingEvent(
|
||||
type="response.code_interpreter_call.interpreting",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseCodeInterpreterCallCompletedEvent(
|
||||
type="response.code_interpreter_call.completed",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item_id=state.current_item_id,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=state.current_output_index,
|
||||
item=ResponseCodeInterpreterToolCallParam(
|
||||
type="code_interpreter_call",
|
||||
id=state.current_item_id,
|
||||
code=previous_item.content[0].text,
|
||||
container_id="auto",
|
||||
outputs=[],
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def emit_tool_action_events(
|
||||
ctx: StreamingHarmonyContext,
|
||||
state: StreamingState,
|
||||
tool_server: ToolServer | None,
|
||||
) -> list[StreamingResponsesResponse]:
|
||||
"""Emit events for tool action turn."""
|
||||
if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0:
|
||||
return []
|
||||
|
||||
events: list[StreamingResponsesResponse] = []
|
||||
previous_item = ctx.parser.messages[-1]
|
||||
|
||||
# Handle browser tool
|
||||
if (
|
||||
tool_server is not None
|
||||
and tool_server.has_tool("browser")
|
||||
and previous_item.recipient is not None
|
||||
and previous_item.recipient.startswith("browser.")
|
||||
):
|
||||
events.extend(emit_browser_tool_events(previous_item, state))
|
||||
|
||||
# Handle tool completion
|
||||
if (
|
||||
tool_server is not None
|
||||
and previous_item.recipient is not None
|
||||
and state.current_item_id is not None
|
||||
and state.sent_output_item_added
|
||||
):
|
||||
recipient = previous_item.recipient
|
||||
if recipient == "python":
|
||||
events.extend(emit_code_interpreter_completion_events(previous_item, state))
|
||||
elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
|
||||
events.extend(
|
||||
emit_mcp_completion_events(
|
||||
recipient, previous_item.content[0].text, state
|
||||
)
|
||||
)
|
||||
|
||||
return events
|
||||
263
vllm/entrypoints/openai/responses/utils.py
Normal file
263
vllm/entrypoints/openai/responses/utils.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
Function as FunctionCallTool,
|
||||
)
|
||||
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
|
||||
from openai.types.responses.response import ToolChoice
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
)
|
||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
from openai.types.responses.tool import Tool
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponseInputOutputItem
|
||||
|
||||
|
||||
def should_continue_final_message(
    request_input: str | list[ResponseInputOutputItem],
) -> bool:
    """
    Determine if the last input message is a partial assistant message
    that should be continued rather than starting a new generation.

    This enables partial message completion similar to Anthropic's Messages API,
    where users can provide an incomplete assistant message and have the model
    continue from where it left off.

    A message is considered partial if:
    1. It's a ResponseOutputMessage or ResponseReasoningItem
    2. Its status is "in_progress" or "incomplete"

    Args:
        request_input: The input to the Responses API request

    Returns:
        True if the final message should be continued, False otherwise
    """
    if isinstance(request_input, str):
        # Simple string input is always a user message
        return False

    if not request_input:
        return False

    last_item = request_input[-1]

    # Check if the last item is a partial assistant message
    if isinstance(last_item, ResponseOutputMessage):
        return last_item.status in ("in_progress", "incomplete")

    # Check if the last item is a partial reasoning item
    if isinstance(last_item, ResponseReasoningItem):
        return last_item.status in ("in_progress", "incomplete")

    if isinstance(last_item, dict):
        # Only message and reasoning items support partial completion for now.
        if last_item.get("type", "message") not in ("message", "reasoning"):
            return False
        return last_item.get("status") in ("in_progress", "incomplete")

    return False
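# Illustrative example (the request shape below is an assumption, not taken
# from this commit): the dict form makes should_continue_final_message return
# True, so generation continues the partial assistant turn instead of
# starting a new one.
#
#     request_input = [
#         {"role": "user", "content": "Write a haiku about the sea"},
#         {
#             "type": "message",
#             "role": "assistant",
#             "status": "in_progress",
#             "content": [{"type": "output_text", "text": "Salt wind over"}],
#         },
#     ]
#     assert should_continue_final_message(request_input) is True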
|
||||
|
||||
def construct_input_messages(
|
||||
*,
|
||||
request_instructions: str | None = None,
|
||||
request_input: str | list[ResponseInputOutputItem],
|
||||
prev_msg: list[ChatCompletionMessageParam] | None = None,
|
||||
prev_response_output: list[ResponseOutputItem] | None = None,
|
||||
):
|
||||
messages: list[ChatCompletionMessageParam] = []
|
||||
if request_instructions:
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": request_instructions,
|
||||
}
|
||||
)
|
||||
|
||||
# Prepend the conversation history.
|
||||
if prev_msg is not None:
|
||||
# Add the previous messages.
|
||||
messages.extend(prev_msg)
|
||||
if prev_response_output is not None:
|
||||
# Add the previous output.
|
||||
for output_item in prev_response_output:
|
||||
# NOTE: We skip the reasoning output.
|
||||
if isinstance(output_item, ResponseOutputMessage):
|
||||
for content in output_item.content:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content.text,
|
||||
}
|
||||
)
|
||||
|
||||
# Append the new input.
|
||||
# Responses API supports simple text inputs without chat format.
|
||||
if isinstance(request_input, str):
|
||||
messages.append({"role": "user", "content": request_input})
|
||||
else:
|
||||
input_messages = construct_chat_messages_with_tool_call(request_input)
|
||||
messages.extend(input_messages)
|
||||
return messages
|
||||
|
||||
|
||||
def _maybe_combine_reasoning_and_tool_call(
    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
    """Many models treat an MCP call and its reasoning as a single message.

    This function checks whether the last message is a reasoning message and
    the current item is an MCP tool call; if so, it attaches the tool call to
    that message and returns it, otherwise it returns None.
    """
if not (
|
||||
isinstance(item, ResponseFunctionToolCall)
|
||||
and item.id
|
||||
and item.id.startswith(MCP_PREFIX)
|
||||
):
|
||||
return None
|
||||
if len(messages) == 0:
|
||||
return None
|
||||
last_message = messages[-1]
|
||||
if not (
|
||||
last_message.get("role") == "assistant"
|
||||
and last_message.get("reasoning") is not None
|
||||
):
|
||||
return None
|
||||
|
||||
last_message["tool_calls"] = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=item.call_id,
|
||||
function=FunctionCallTool(
|
||||
name=item.name,
|
||||
arguments=item.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
return last_message
|
||||
|
||||
|
||||
def construct_chat_messages_with_tool_call(
    input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
    """Wrap _construct_single_message_from_response_item.

    Some chat messages are built from multiple response items: for example,
    a reasoning item and an MCP tool call are two response items but form a
    single chat message.
    """
    messages: list[ChatCompletionMessageParam] = []
    for item in input_messages:
        maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
        if maybe_combined_message is not None:
            messages[-1] = maybe_combined_message
        else:
            messages.append(_construct_single_message_from_response_item(item))

    return messages
|
||||
|
||||
def _construct_single_message_from_response_item(
|
||||
item: ResponseInputOutputItem,
|
||||
) -> ChatCompletionMessageParam:
|
||||
if isinstance(item, ResponseFunctionToolCall):
|
||||
# Append the function call as a tool call.
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=item.call_id,
|
||||
function=FunctionCallTool(
|
||||
name=item.name,
|
||||
arguments=item.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
)
|
||||
elif isinstance(item, ResponseReasoningItem):
|
||||
reasoning_content = ""
|
||||
if item.encrypted_content:
|
||||
raise ValueError("Encrypted content is not supported.")
|
||||
if len(item.summary) == 1:
|
||||
reasoning_content = item.summary[0].text
|
||||
elif item.content and len(item.content) == 1:
|
||||
reasoning_content = item.content[0].text
|
||||
return {
|
||||
"role": "assistant",
|
||||
"reasoning": reasoning_content,
|
||||
}
|
||||
elif isinstance(item, ResponseOutputMessage):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": item.content[0].text,
|
||||
}
|
||||
elif isinstance(item, ResponseFunctionToolCallOutputItem):
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=item.output,
|
||||
tool_call_id=item.call_id,
|
||||
)
|
||||
elif isinstance(item, dict) and item.get("type") == "function_call_output":
|
||||
# Append the function call output as a tool message.
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=item.get("output"),
|
||||
tool_call_id=item.get("call_id"),
|
||||
)
|
||||
return item # type: ignore
|
||||
|
||||
|
||||
def extract_tool_types(tools: list[Tool]) -> set[str]:
|
||||
"""
|
||||
Extracts the tool types from the given tools.
|
||||
"""
|
||||
tool_types: set[str] = set()
|
||||
for tool in tools:
|
||||
if tool.type == "mcp":
|
||||
# Allow the MCP Tool type to enable built in tools if the
|
||||
# server_label is allowlisted in
|
||||
# envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS
|
||||
if tool.server_label in envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
|
||||
tool_types.add(tool.server_label)
|
||||
else:
|
||||
tool_types.add(tool.type)
|
||||
return tool_types
|
||||
|
||||
|
||||
def convert_tool_responses_to_completions_format(tool: dict) -> dict:
    """
    Convert a flat tool schema:
        {"type": "function", "name": "...", "description": "...", "parameters": {...}}
    into:
        {"type": "function", "function": {...}}
    """
    return {
        "type": "function",
        "function": tool,
    }
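# Illustrative transformation (the tool name is a placeholder):
#
#     convert_tool_responses_to_completions_format(
#         {"type": "function", "name": "get_weather", "parameters": {...}}
#     )
#     # -> {"type": "function",
#     #     "function": {"type": "function", "name": "get_weather",
#     #                  "parameters": {...}}}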
|
||||
|
||||
def construct_tool_dicts(
|
||||
tools: list[Tool], tool_choice: ToolChoice
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if tools is None or (tool_choice == "none"):
|
||||
tool_dicts = None
|
||||
else:
|
||||
tool_dicts = [
|
||||
convert_tool_responses_to_completions_format(tool.model_dump())
|
||||
for tool in tools
|
||||
]
|
||||
return tool_dicts
|
||||
843
vllm/entrypoints/openai/run_batch.py
Normal file
843
vllm/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,843 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any, TypeAlias
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from fastapi import UploadFile
|
||||
from prometheus_client import start_http_server
|
||||
from pydantic import Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from starlette.datastructures import State
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import init_app_state
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.cli_args import BaseFrontendArgs
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
OpenAIBaseModel,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BatchTranscriptionRequest(TranscriptionRequest):
|
||||
"""
|
||||
Batch transcription request that uses file_url instead of file.
|
||||
|
||||
This class extends TranscriptionRequest but replaces the file field
|
||||
with file_url to support batch processing from audio files written in JSON format.
|
||||
"""
|
||||
|
||||
file_url: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Either a URL of the audio or a data URL with base64 encoded audio data. "
|
||||
),
|
||||
)
|
||||
|
||||
# Override file to be optional and unused for batch processing
|
||||
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_no_file(cls, data: Any):
|
||||
"""Ensure file field is not provided in batch requests."""
|
||||
if isinstance(data, dict) and "file" in data:
|
||||
raise ValueError(
|
||||
"The 'file' field is not supported in batch requests. "
|
||||
"Use 'file_url' instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class BatchTranslationRequest(TranslationRequest):
|
||||
"""
|
||||
Batch translation request that uses file_url instead of file.
|
||||
|
||||
This class extends TranslationRequest but replaces the file field
|
||||
with file_url to support batch processing from audio files written in JSON format.
|
||||
"""
|
||||
|
||||
file_url: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Either a URL of the audio or a data URL with base64 encoded audio data. "
|
||||
),
|
||||
)
|
||||
|
||||
# Override file to be optional and unused for batch processing
|
||||
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_no_file(cls, data: Any):
|
||||
"""Ensure file field is not provided in batch requests."""
|
||||
if isinstance(data, dict) and "file" in data:
|
||||
raise ValueError(
|
||||
"The 'file' field is not supported in batch requests. "
|
||||
"Use 'file_url' instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
BatchRequestInputBody: TypeAlias = (
|
||||
ChatCompletionRequest
|
||||
| EmbeddingRequest
|
||||
| ScoreRequest
|
||||
| RerankRequest
|
||||
| BatchTranscriptionRequest
|
||||
| BatchTranslationRequest
|
||||
)
|
||||
|
||||
|
||||
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Supported endpoints are /v1/chat/completions, /v1/embeddings,
    */score, */rerank, /v1/audio/transcriptions, and /v1/audio/translations.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request (see the NOTE
    # above for the supported endpoints).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        # Use the url to disambiguate the request model
        url: str = info.data["url"]
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url.endswith("/score"):
            return TypeAdapter(ScoreRequest).validate_python(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        if url == "/v1/audio/transcriptions":
            return BatchTranscriptionRequest.model_validate(value)
        if url == "/v1/audio/translations":
            return BatchTranslationRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
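# Illustrative batch input line (JSONL; model name and message are
# placeholders, not values from this commit):
#
#     {"custom_id": "request-1", "method": "POST",
#      "url": "/v1/chat/completions",
#      "body": {"model": "meta-llama/Llama-3.1-8B-Instruct",
#               "messages": [{"role": "user", "content": "Hello!"}]}}
#
# check_type_for_url validates "body" as a ChatCompletionRequest because the
# "url" field is /v1/chat/completions.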
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
status_code: int = 200
|
||||
|
||||
# An unique identifier for the API request.
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: (
|
||||
ChatCompletionResponse
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| None
|
||||
) = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch output and error files
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs.
|
||||
custom_id: str
|
||||
|
||||
response: BatchResponseData | None
|
||||
|
||||
# For requests that failed with a non-HTTP error, this will contain more
|
||||
# information on the cause of the failure.
|
||||
error: Any | None
|
||||
|
||||
|
||||
@config
|
||||
class BatchFrontendArgs(BaseFrontendArgs):
|
||||
"""Arguments for the batch runner frontend."""
|
||||
|
||||
input_file: str | None = None
|
||||
"""The path or url to a single input file. Currently supports local file
|
||||
paths, or the http protocol (http or https). If a URL is specified,
|
||||
the file should be available via HTTP GET."""
|
||||
output_file: str | None = None
|
||||
"""The path or url to a single output file. Currently supports
|
||||
local file paths, or web (http or https) urls. If a URL is specified,
|
||||
the file should be available via HTTP PUT."""
|
||||
output_tmp_dir: str | None = None
|
||||
"""The directory to store the output file before uploading it
|
||||
to the output URL."""
|
||||
enable_metrics: bool = False
|
||||
"""Enable Prometheus metrics"""
|
||||
host: str | None = None
|
||||
"""Host name for the Prometheus metrics server
|
||||
(only needed if enable-metrics is set)."""
|
||||
port: int = 8000
|
||||
"""Port number for the Prometheus metrics server
|
||||
(only needed if enable-metrics is set)."""
|
||||
url: str = "0.0.0.0"
|
||||
"""[DEPRECATED] Host name for the Prometheus metrics server
|
||||
(only needed if enable-metrics is set). Use --host instead."""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
cls,
|
||||
frontend_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)
|
||||
|
||||
frontend_kwargs["input_file"]["flags"] = ["-i"]
|
||||
frontend_kwargs["input_file"]["required"] = True
|
||||
frontend_kwargs["output_file"]["flags"] = ["-o"]
|
||||
frontend_kwargs["output_file"]["required"] = True
|
||||
|
||||
frontend_kwargs["enable_metrics"]["action"] = "store_true"
|
||||
|
||||
frontend_kwargs["url"]["deprecated"] = True
|
||||
return frontend_kwargs
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser = BatchFrontendArgs.add_cli_args(parser)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
|
||||
args = make_arg_parser(parser).parse_args()
|
||||
|
||||
# Backward compatibility: If --url is set, use it for host
|
||||
url_explicit = any(arg == "--url" or arg.startswith("--url=") for arg in sys.argv)
|
||||
host_explicit = any(
|
||||
arg == "--host" or arg.startswith("--host=") for arg in sys.argv
|
||||
)
|
||||
if url_explicit and hasattr(args, "url") and not host_explicit:
|
||||
args.host = args.url
|
||||
logger.warning_once(
|
||||
"Using --url for metrics is deprecated. Please use --host instead."
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
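# Typical invocation of this batch runner (model and file names are
# placeholders):
#
#     python -m vllm.entrypoints.openai.run_batch \
#         -i batch_input.jsonl -o batch_output.jsonl \
#         --model meta-llama/Meta-Llama-3-8B-Instruct
#
# --url is still accepted for the metrics server but is deprecated in favor
# of --host, as handled in parse_args above.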
# Explicitly use a pure-text bar format with a newline at the end. This makes
# the progress-bar animation invisible, but it avoids clashing with Ray or
# multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
|
||||
class BatchProgressTracker:
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: tqdm | None = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
self._pbar = tqdm(
|
||||
total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_local_file(
|
||||
output_path: str, batch_outputs: list[BatchRequestOutput]
|
||||
) -> None:
|
||||
"""
|
||||
Write the responses to a local file.
|
||||
output_path: The path to write the responses to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
"""
|
||||
# We should make this async, but as long as run_batch runs as a
|
||||
# standalone program, blocking the event loop won't affect performance.
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=f)
|
||||
|
||||
|
||||
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
|
||||
"""
|
||||
Upload a local file to a URL.
|
||||
output_url: The URL to upload the file to.
|
||||
data_or_file: Either the data to upload or the path to the file to upload.
|
||||
from_file: If True, data_or_file is the path to the file to upload.
|
||||
"""
|
||||
# Timeout is a common issue when uploading large files.
|
||||
# We retry max_retries times before giving up.
|
||||
max_retries = 5
|
||||
# Number of seconds to wait before retrying.
|
||||
delay = 5
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# We increase the timeout to 1000 seconds to allow
|
||||
# for large files (default is 300).
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=1000)
|
||||
) as session:
|
||||
if from_file:
|
||||
with open(data_or_file, "rb") as file:
|
||||
async with session.put(output_url, data=file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to upload file.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}"
|
||||
)
|
||||
else:
|
||||
async with session.put(output_url, data=data_or_file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to upload data.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.error(
|
||||
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
|
||||
attempt,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
|
||||
) from e
|
||||
|
||||
|
||||
async def write_file(
|
||||
path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str
|
||||
) -> None:
|
||||
"""
|
||||
Write batch_outputs to a file or upload to a URL.
|
||||
path_or_url: The path or URL to write batch_outputs to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
output_tmp_dir: The directory to store the output file before uploading it
|
||||
to the output URL.
|
||||
"""
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
if output_tmp_dir is None:
|
||||
logger.info("Writing outputs to memory buffer")
|
||||
output_buffer = StringIO()
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=output_buffer)
|
||||
output_buffer.seek(0)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(
|
||||
path_or_url,
|
||||
output_buffer.read().strip().encode("utf-8"),
|
||||
from_file=False,
|
||||
)
|
||||
else:
|
||||
# Write responses to a temporary file and then upload it to the URL.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=output_tmp_dir,
|
||||
prefix="tmp_batch_output_",
|
||||
suffix=".jsonl",
|
||||
) as f:
|
||||
logger.info("Writing outputs to temporary local file %s", f.name)
|
||||
await write_local_file(f.name, batch_outputs)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(path_or_url, f.name, from_file=True)
|
||||
else:
|
||||
logger.info("Writing outputs to local file %s", path_or_url)
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
async def download_bytes_from_url(url: str) -> bytes:
|
||||
"""
|
||||
Download data from a URL or decode from a data URL.
|
||||
|
||||
Args:
|
||||
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
|
||||
|
||||
Returns:
|
||||
Data as bytes
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Handle data URLs (base64 encoded)
|
||||
if parsed.scheme == "data":
|
||||
# Format: data:...;base64,<base64_data>
|
||||
if "," in url:
|
||||
header, data = url.split(",", 1)
|
||||
if "base64" in header:
|
||||
return base64.b64decode(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data URL encoding: {header}")
|
||||
else:
|
||||
raise ValueError(f"Invalid data URL format: {url}")
|
||||
|
||||
# Handle HTTP/HTTPS URLs
|
||||
elif parsed.scheme in ("http", "https"):
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(url) as resp,
|
||||
):
|
||||
if resp.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to download data from URL: {url}. Status: {resp.status}"
|
||||
)
|
||||
return await resp.read()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed.scheme}. "
|
||||
"Supported schemes: http, https, data"
|
||||
)
|
||||
|
||||
|
||||
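# Illustrative call (the base64 payload is made up):
#
#     data = await download_bytes_from_url(
#         "data:audio/wav;base64,UklGRiQAAABXQVZF..."
#     )
#
# The "data:" branch above splits on the first comma and base64-decodes the
# remainder; http/https URLs are fetched with aiohttp instead.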
def make_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str
|
||||
) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str
|
||||
) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(
|
||||
serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker,
|
||||
) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
ScoreResponse,
|
||||
RerankResponse,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.error.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode"
|
||||
)
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
WrapperFn: TypeAlias = Callable[[Callable], Callable]
|
||||
|
||||
|
||||
def handle_endpoint_request(
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker,
|
||||
url_matcher: Callable[[str], bool],
|
||||
handler_getter: Callable[[], Callable | None],
|
||||
wrapper_fn: WrapperFn | None = None,
|
||||
) -> Awaitable[BatchRequestOutput] | None:
|
||||
"""
|
||||
Generic handler for endpoint requests.
|
||||
|
||||
Args:
|
||||
request: The batch request input
|
||||
tracker: Progress tracker for the batch
|
||||
url_matcher: Function that takes a URL and returns True if it matches
|
||||
handler_getter: Function that returns the handler function or None
|
||||
wrapper_fn: Optional function to wrap the handler (e.g., for transcriptions)
|
||||
|
||||
Returns:
|
||||
Awaitable[BatchRequestOutput] if the request was handled,
|
||||
None if URL didn't match
|
||||
"""
|
||||
if not url_matcher(request.url):
|
||||
return None
|
||||
|
||||
handler_fn = handler_getter()
|
||||
if handler_fn is None:
|
||||
error_msg = f"Model does not support endpoint: {request.url}"
|
||||
return make_async_error_request_output(request, error_msg=error_msg)
|
||||
|
||||
# Apply wrapper if provided (e.g., for transcriptions/translations)
|
||||
if wrapper_fn is not None:
|
||||
handler_fn = wrapper_fn(handler_fn)
|
||||
|
||||
tracker.submitted()
|
||||
return run_request(handler_fn, request, tracker)
|
||||
|
||||
|
||||
def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
|
||||
"""
|
||||
Factory function to create a wrapper for transcription/translation handlers.
|
||||
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
|
||||
to TranscriptionRequest or TranslationRequest and calls the appropriate handler.
|
||||
|
||||
Args:
|
||||
is_translation: If True, process as translation; otherwise process
|
||||
as transcription
|
||||
|
||||
Returns:
|
||||
A function that takes a handler and returns a wrapped handler
|
||||
"""
|
||||
|
||||
def wrapper(handler_fn: Callable):
|
||||
async def transcription_wrapper(
|
||||
batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
|
||||
) -> (
|
||||
TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| ErrorResponse
|
||||
):
|
||||
try:
|
||||
# Download data from URL
|
||||
audio_data = await download_bytes_from_url(batch_request_body.file_url)
|
||||
|
||||
# Create a mock file from the downloaded audio data
|
||||
mock_file = UploadFile(
|
||||
file=BytesIO(audio_data),
|
||||
filename="audio.bin",
|
||||
)
|
||||
|
||||
# Convert batch request to regular request
|
||||
# by copying all fields except file_url and setting file to mock_file
|
||||
request_dict = batch_request_body.model_dump(exclude={"file_url"})
|
||||
request_dict["file"] = mock_file
|
||||
|
||||
if is_translation:
|
||||
# Create TranslationRequest from BatchTranslationRequest
|
||||
translation_request = TranslationRequest.model_validate(
|
||||
request_dict
|
||||
)
|
||||
return await handler_fn(audio_data, translation_request)
|
||||
else:
|
||||
# Create TranscriptionRequest from BatchTranscriptionRequest
|
||||
transcription_request = TranscriptionRequest.model_validate(
|
||||
request_dict
|
||||
)
|
||||
return await handler_fn(audio_data, transcription_request)
|
||||
except Exception as e:
|
||||
operation = "translation" if is_translation else "transcription"
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=f"Failed to process {operation}: {str(e)}",
|
||||
type="BadRequestError",
|
||||
code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
)
|
||||
|
||||
return transcription_wrapper
|
||||
|
||||
return wrapper
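# Illustrative sketch: wrapping a transcription handler with the factory above.
# `serving_transcription` is a hypothetical OpenAIServingTranscription instance;
# the wrapped handler accepts a BatchTranscriptionRequest, downloads the audio
# from file_url, and forwards both to create_transcription.
#
#     wrap = make_transcription_wrapper(is_translation=False)
#     handler = wrap(serving_transcription.create_transcription)
#     response = await handler(batch_request.body)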
|
||||
|
||||
|
||||
async def build_endpoint_registry(
|
||||
engine_client: EngineClient,
|
||||
args: Namespace,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Build the endpoint registry with all serving objects and handler configurations.
|
||||
|
||||
Args:
|
||||
engine_client: The engine client
|
||||
args: Command line arguments
|
||||
|
||||
Returns:
|
||||
Dictionary mapping endpoint keys to their configurations
|
||||
"""
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
# Create a state object to hold serving objects
|
||||
state = State()
|
||||
|
||||
# Initialize all serving objects using init_app_state
|
||||
# This provides full functionality including chat template processing,
|
||||
# LoRA support, tool servers, etc.
|
||||
await init_app_state(engine_client, state, args, supported_tasks)
|
||||
|
||||
# Get serving objects from state (defaulting to None if not set)
|
||||
openai_serving_chat = getattr(state, "openai_serving_chat", None)
|
||||
openai_serving_embedding = getattr(state, "openai_serving_embedding", None)
|
||||
openai_serving_scores = getattr(state, "openai_serving_scores", None)
|
||||
openai_serving_transcription = getattr(state, "openai_serving_transcription", None)
|
||||
openai_serving_translation = getattr(state, "openai_serving_translation", None)
|
||||
|
||||
# Registry of endpoint configurations
|
||||
endpoint_registry: dict[str, dict[str, Any]] = {
|
||||
"completions": {
|
||||
"url_matcher": lambda url: url == "/v1/chat/completions",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_chat.create_chat_completion
|
||||
if openai_serving_chat is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"embeddings": {
|
||||
"url_matcher": lambda url: url == "/v1/embeddings",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_embedding.create_embedding
|
||||
if openai_serving_embedding is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"score": {
|
||||
"url_matcher": lambda url: url.endswith("/score"),
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_scores.create_score
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"rerank": {
|
||||
"url_matcher": lambda url: url.endswith("/rerank"),
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_scores.do_rerank
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"transcriptions": {
|
||||
"url_matcher": lambda url: url == "/v1/audio/transcriptions",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_transcription.create_transcription
|
||||
if openai_serving_transcription is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": make_transcription_wrapper(is_translation=False),
|
||||
},
|
||||
"translations": {
|
||||
"url_matcher": lambda url: url == "/v1/audio/translations",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_translation.create_translation
|
||||
if openai_serving_translation is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": make_transcription_wrapper(is_translation=True),
|
||||
},
|
||||
}
|
||||
|
||||
return endpoint_registry
|
||||
|
||||
|
||||
def validate_run_batch_args(args):
|
||||
valid_reasoning_parsers = ReasoningParserManager.list_registered()
|
||||
if (
|
||||
reasoning_parser := args.structured_outputs_config.reasoning_parser
|
||||
) and reasoning_parser not in valid_reasoning_parsers:
|
||||
raise KeyError(
|
||||
f"invalid reasoning parser: {reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
|
||||
)
|
||||
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
endpoint_registry = await build_endpoint_registry(
|
||||
engine_client=engine_client,
|
||||
args=args,
|
||||
)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: list[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Use the last segment of the URL as the endpoint key.
|
||||
# More advanced URL matching is done in url_matcher of endpoint_registry.
|
||||
endpoint_key = request.url.split("/")[-1]
|
||||
|
||||
result = None
|
||||
if endpoint_key in endpoint_registry:
|
||||
endpoint_config = endpoint_registry[endpoint_key]
|
||||
result = handle_endpoint_request(
|
||||
request,
|
||||
tracker,
|
||||
url_matcher=endpoint_config["url_matcher"],
|
||||
handler_getter=endpoint_config["handler_getter"],
|
||||
wrapper_fn=endpoint_config["wrapper_fn"],
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
response_futures.append(result)
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /v1/audio/transcriptions, /v1/audio/translations, /score, "
|
||||
" /rerank. See vllm/entrypoints/openai/api_server.py "
|
||||
"for supported score/rerank versions.",
|
||||
)
|
||||
)
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
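# Illustrative sketch: one line of the JSONL input file consumed by run_batch.
# Each line parses into a BatchRequestInput; the custom_id, model, and message
# values below are hypothetical placeholders.
#
#     {"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions",
#      "body": {"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]}}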
|
||||
|
||||
|
||||
async def main(args: Namespace):
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
validate_run_batch_args(args)
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
await run_batch(engine_client, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.host)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
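# Illustrative sketch: the batch runner is typically launched as a module with
# an input and an output file. The module path and flags below are assumptions
# and may differ in your installation.
#
#     python -m vllm.entrypoints.openai.run_batch \
#         -i requests.jsonl -o results.jsonl --model <model-name>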
|
||||
382
vllm/entrypoints/openai/server_utils.py
Normal file
382
vllm/entrypoints/openai/server_utils.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
|
||||
import pydantic
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
|
||||
from vllm.entrypoints.utils import sanitize_message
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
|
||||
logger = init_logger("vllm.entrypoints.openai.server_utils")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization Bearer token exists and equals any of the configured API keys.
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are two cases in which authentication is skipped:
|
||||
1. The HTTP method is OPTIONS.
|
||||
2. The request path doesn't start with /v1 (e.g. /health).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
|
||||
# scope["type"] can be "lifespan" or "startup" for example,
|
||||
# in which case we don't need to do anything
|
||||
return self.app(scope, receive, send)
|
||||
root_path = scope.get("root_path", "")
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
|
||||
return response(scope, receive, send)
|
||||
return self.app(scope, receive, send)
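# Illustrative sketch: installing the middleware on a FastAPI app. The token
# value is a placeholder; several tokens may be passed and each is hashed once.
#
#     app.add_middleware(AuthenticationMiddleware, tokens=["my-secret-token"])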
|
||||
|
||||
|
||||
class XRequestIdMiddleware:
|
||||
"""
|
||||
Middleware that sets the X-Request-Id header for each response
|
||||
to a random uuid4 (hex) value if the header isn't already
|
||||
present in the request; otherwise the request id provided by the client is reused.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Extract the request headers.
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
async def send_with_request_id(message: Message) -> None:
|
||||
"""
|
||||
Custom send function to mutate the response headers
|
||||
and append X-Request-Id to it.
|
||||
"""
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = MutableHeaders(raw=message["headers"])
|
||||
request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex)
|
||||
response_headers.append("X-Request-Id", request_id)
|
||||
await send(message)
|
||||
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
try:
|
||||
with open(log_config_file) as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load log config from file %s: error %s", log_config_file, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_uvicorn_log_config(args: Namespace) -> dict | None:
|
||||
"""
|
||||
Get the uvicorn log config based on the provided arguments.
|
||||
|
||||
Priority:
|
||||
1. If log_config_file is specified, use it
|
||||
2. If disable_access_log_for_endpoints is specified, create a config with
|
||||
the access log filter
|
||||
3. Otherwise, return None (use uvicorn defaults)
|
||||
"""
|
||||
# First, try to load from file if specified
|
||||
log_config = load_log_config(args.log_config_file)
|
||||
if log_config is not None:
|
||||
return log_config
|
||||
|
||||
# If endpoints to filter are specified, create a config with the filter
|
||||
if args.disable_access_log_for_endpoints:
|
||||
from vllm.logging_utils import create_uvicorn_log_config
|
||||
|
||||
# Parse comma-separated string into list
|
||||
excluded_paths = [
|
||||
p.strip()
|
||||
for p in args.disable_access_log_for_endpoints.split(",")
|
||||
if p.strip()
|
||||
]
|
||||
return create_uvicorn_log_config(
|
||||
excluded_paths=excluded_paths,
|
||||
log_level=args.uvicorn_log_level,
|
||||
)
|
||||
|
||||
return None
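# Illustrative sketch: with hypothetical CLI values, health and metrics probes
# are dropped from uvicorn's access log while other requests keep being logged.
#
#     args.log_config_file = None
#     args.disable_access_log_for_endpoints = "/health, /metrics"
#     args.uvicorn_log_level = "info"
#     log_config = get_uvicorn_log_config(args)  # filter built from both paths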
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionStreamResponse,
|
||||
)
|
||||
|
||||
# Try using Completion types for type-safe parsing
|
||||
if chunk_data.get("object") == "chat.completion.chunk":
|
||||
chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
|
||||
if chat_response.choices and chat_response.choices[0].delta.content:
|
||||
return chat_response.choices[0].delta.content
|
||||
elif chunk_data.get("object") == "text_completion":
|
||||
completion_response = CompletionStreamResponse.model_validate(chunk_data)
|
||||
if completion_response.choices and completion_response.choices[0].text:
|
||||
return completion_response.choices[0].text
|
||||
except pydantic.ValidationError:
|
||||
# Fallback to manual parsing
|
||||
if "choices" in chunk_data and chunk_data["choices"]:
|
||||
choice = chunk_data["choices"][0]
|
||||
if "delta" in choice and choice["delta"].get("content"):
|
||||
return choice["delta"]["content"]
|
||||
elif choice.get("text"):
|
||||
return choice["text"]
|
||||
return ""
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
"""Robust Server-Sent Events decoder for streaming responses."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
self.content_buffer = []
|
||||
|
||||
def decode_chunk(self, chunk: bytes) -> list[dict]:
|
||||
"""Decode a chunk of SSE data and return parsed events."""
|
||||
import json
|
||||
|
||||
try:
|
||||
chunk_str = chunk.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Skip malformed chunks
|
||||
return []
|
||||
|
||||
self.buffer += chunk_str
|
||||
events = []
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in self.buffer:
|
||||
line, self.buffer = self.buffer.split("\n", 1)
|
||||
line = line.rstrip("\r") # Handle CRLF
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
events.append({"type": "done"})
|
||||
elif data_str:
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
events.append({"type": "data", "data": event_data})
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed JSON
|
||||
continue
|
||||
|
||||
return events
|
||||
|
||||
def extract_content(self, event_data: dict) -> str:
|
||||
"""Extract content from event data."""
|
||||
return _extract_content_from_chunk(event_data)
|
||||
|
||||
def add_content(self, content: str) -> None:
|
||||
"""Add content to the buffer."""
|
||||
if content:
|
||||
self.content_buffer.append(content)
|
||||
|
||||
def get_complete_content(self) -> str:
|
||||
"""Get the complete buffered content."""
|
||||
return "".join(self.content_buffer)
|
||||
|
||||
|
||||
def _log_streaming_response(response, response_body: list) -> None:
|
||||
"""Log streaming response with robust SSE parsing."""
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
sse_decoder = SSEDecoder()
|
||||
chunk_count = 0
|
||||
|
||||
def buffered_iterator():
|
||||
nonlocal chunk_count
|
||||
|
||||
for chunk in response_body:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
|
||||
# Parse SSE events from chunk
|
||||
events = sse_decoder.decode_chunk(chunk)
|
||||
|
||||
for event in events:
|
||||
if event["type"] == "data":
|
||||
content = sse_decoder.extract_content(event["data"])
|
||||
sse_decoder.add_content(content)
|
||||
elif event["type"] == "done":
|
||||
# Log complete content when done
|
||||
full_content = sse_decoder.get_complete_content()
|
||||
if full_content:
|
||||
# Truncate if too long
|
||||
if len(full_content) > 2048:
|
||||
full_content = full_content[:2048] + "...[truncated]"
|
||||
logger.info(
|
||||
"response_body={streaming_complete: content=%r, chunks=%d}",
|
||||
full_content,
|
||||
chunk_count,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"response_body={streaming_complete: no_content, chunks=%d}",
|
||||
chunk_count,
|
||||
)
|
||||
return
|
||||
|
||||
response.body_iterator = iterate_in_threadpool(buffered_iterator())
|
||||
logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
|
||||
|
||||
|
||||
def _log_non_streaming_response(response_body: list) -> None:
|
||||
"""Log non-streaming response."""
|
||||
try:
|
||||
decoded_body = response_body[0].decode()
|
||||
logger.info("response_body={%s}", decoded_body)
|
||||
except UnicodeDecodeError:
|
||||
logger.info("response_body={<binary_data>}")
|
||||
|
||||
|
||||
async def log_response(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
response_body = [section async for section in response.body_iterator]
|
||||
response.body_iterator = iterate_in_threadpool(iter(response_body))
|
||||
# Check if this is a streaming response by looking at content-type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
is_streaming = content_type == "text/event-stream; charset=utf-8"
|
||||
|
||||
# Log response body based on type
|
||||
if not response_body:
|
||||
logger.info("response_body={<empty>}")
|
||||
elif is_streaming:
|
||||
_log_streaming_response(response, response_body)
|
||||
else:
|
||||
_log_non_streaming_response(response_body)
|
||||
return response
|
||||
|
||||
|
||||
async def http_exception_handler(_: Request, exc: HTTPException):
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(exc.detail),
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
|
||||
async def validation_exception_handler(_: Request, exc: RequestValidationError):
|
||||
param = None
|
||||
errors = exc.errors()
|
||||
for error in errors:
|
||||
if "ctx" in error and "error" in error["ctx"]:
|
||||
ctx_error = error["ctx"]["error"]
|
||||
if isinstance(ctx_error, VLLMValidationError):
|
||||
param = ctx_error.parameter
|
||||
break
|
||||
|
||||
exc_str = str(exc)
|
||||
errors_str = str(errors)
|
||||
|
||||
if errors and errors_str and errors_str != exc_str:
|
||||
message = f"{exc_str} {errors_str}"
|
||||
else:
|
||||
message = exc_str
|
||||
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(message),
|
||||
type=HTTPStatus.BAD_REQUEST.phrase,
|
||||
code=HTTPStatus.BAD_REQUEST,
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
_running_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
if app.state.log_stats:
|
||||
engine_client: EngineClient = app.state.engine_client
|
||||
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
|
||||
await engine_client.do_log_stats()
|
||||
|
||||
task = asyncio.create_task(_force_log())
|
||||
_running_tasks.add(task)
|
||||
task.add_done_callback(_running_tasks.remove)
|
||||
else:
|
||||
task = None
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
freeze_gc_heap()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
finally:
|
||||
# Ensure app state including engine ref is gc'd
|
||||
del app.state
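# Illustrative sketch: the helpers in this module are wired into the app by the
# server entrypoint; a minimal assembly, assuming `engine_client` already
# exists, looks roughly like this.
#
#     app = FastAPI(lifespan=lifespan)
#     app.state.engine_client = engine_client
#     app.state.log_stats = True
#     app.middleware("http")(log_response)
#     app.add_exception_handler(HTTPException, http_exception_handler)
#     app.add_exception_handler(RequestValidationError, validation_exception_handler)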
|
||||
2
vllm/entrypoints/openai/speech_to_text/__init__.py
Normal file
2
vllm/entrypoints/openai/speech_to_text/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
159
vllm/entrypoints/openai/speech_to_text/api_router.py
Normal file
159
vllm/entrypoints/openai/speech_to_text/api_router.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
TranslationResponseVariant,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(
|
||||
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
|
||||
):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Transcriptions API"
|
||||
)
|
||||
|
||||
audio_data = await request.file.read()
|
||||
try:
|
||||
generator = await handler.create_transcription(audio_data, request, raw_request)
|
||||
except Exception as e:
|
||||
return handler.create_error_response(e)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/translations",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(
|
||||
request: Annotated[TranslationRequest, Form()], raw_request: Request
|
||||
):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Translations API"
|
||||
)
|
||||
|
||||
audio_data = await request.file.read()
|
||||
try:
|
||||
generator = await handler.create_translation(audio_data, request, raw_request)
|
||||
except Exception as e:
|
||||
return handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranslationResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
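# Illustrative sketch: once the router is attached, the endpoint accepts
# multipart form data. The file name and model below are placeholders.
#
#     curl http://localhost:8000/v1/audio/transcriptions \
#         -F file=@sample.wav -F model=<model-name> -F response_format=json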
|
||||
|
||||
|
||||
def init_transcription_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
state.openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
545
vllm/entrypoints/openai/speech_to_text/protocol.py
Normal file
545
vllm/entrypoints/openai/speech_to_text/protocol.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
OpenAIBaseModel,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
|
||||
|
||||
class TranscriptionStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
|
||||
object: Literal["transcription.chunk"] = "transcription.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranscriptionResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
## Protocols for Audio
|
||||
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
|
||||
|
||||
|
||||
class TranscriptionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to transcribe, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: str | None = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
language: str | None = None
|
||||
"""The language of the input audio.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy and latency.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
## TODO (varun) : Support if set to 0, certain thresholds are met !!
|
||||
|
||||
timestamp_granularities: list[Literal["word", "segment"]] = Field(
|
||||
alias="timestamp_granularities[]", default=[]
|
||||
)
|
||||
"""The timestamp granularities to populate for this transcription.
|
||||
|
||||
`response_format` must be set to `verbose_json` to use timestamp granularities.
|
||||
Either or both of these options are supported: `word`, or `segment`. Note:
|
||||
There is no additional latency for segment timestamps, but generating word
|
||||
timestamps incurs additional latency.
|
||||
"""
|
||||
|
||||
stream: bool | None = False
|
||||
"""When set, it will enable output to be streamed in a similar fashion
|
||||
as the Chat Completion endpoint.
|
||||
"""
|
||||
# --8<-- [start:transcription-extra-params]
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: bool | None = False
|
||||
stream_continuous_usage_stats: bool | None = False
|
||||
|
||||
vllm_xargs: dict[str, str | int | float] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:transcription-extra-params]
|
||||
|
||||
to_language: str | None = None
|
||||
"""The language of the output audio we transcribe to.
|
||||
|
||||
Please note that this is not currently used by supported models; it is a
placeholder for future use, matching the Translation API.
|
||||
"""
|
||||
|
||||
# --8<-- [start:transcription-sampling-params]
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
|
||||
top_p: float | None = None
|
||||
"""Enables nucleus (top-p) sampling, where tokens are selected from the
|
||||
smallest possible set whose cumulative probability exceeds `p`.
|
||||
"""
|
||||
|
||||
top_k: int | None = None
|
||||
"""Limits sampling to the `k` most probable tokens at each step."""
|
||||
|
||||
min_p: float | None = None
|
||||
"""Filters out tokens with a probability lower than `min_p`, ensuring a
|
||||
minimum likelihood threshold during sampling.
|
||||
"""
|
||||
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
"""The seed to use for sampling."""
|
||||
|
||||
frequency_penalty: float | None = 0.0
|
||||
"""The frequency penalty to use for sampling."""
|
||||
|
||||
repetition_penalty: float | None = None
|
||||
"""The repetition penalty to use for sampling."""
|
||||
|
||||
presence_penalty: float | None = 0.0
|
||||
"""The presence penalty to use for sampling."""
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:transcription-sampling-params]
|
||||
|
||||
# Default sampling parameters for transcription requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 0,
|
||||
"min_p": 0.0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
||||
)
|
||||
if (top_k := self.top_k) is None:
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||
)
|
||||
if (min_p := self.min_p) is None:
|
||||
min_p = default_sampling_params.get(
|
||||
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
|
||||
)
|
||||
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
"repetition_penalty",
|
||||
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
seed=self.seed,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
presence_penalty=self.presence_penalty,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
extra_args=self.vllm_xargs,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_transcription_request(cls, data):
|
||||
if isinstance(data.get("file"), str):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
||||
detail="Expected 'file' to be a file-like object, not 'str'.",
|
||||
)
|
||||
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
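# Illustrative sketch: turning a request into engine sampling parameters.
# The request values and default_max_tokens below are arbitrary placeholders;
# in the server they come from form data and the model config.
#
#     req = TranscriptionRequest(file=upload, language="en", temperature=0.0)
#     params = req.to_sampling_params(default_max_tokens=448)
#     params.output_kind  # FINAL_ONLY unless stream=True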
|
||||
|
||||
|
||||
# Transcription response objects
|
||||
class TranscriptionUsageAudio(OpenAIBaseModel):
|
||||
type: Literal["duration"] = "duration"
|
||||
seconds: int
|
||||
|
||||
|
||||
class TranscriptionResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The transcribed text."""
|
||||
usage: TranscriptionUsageAudio
|
||||
|
||||
|
||||
class TranscriptionWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranscriptionSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The transcribed text."""
|
||||
|
||||
segments: list[TranscriptionSegment] | None = None
|
||||
"""Segments of the transcribed text and their corresponding details."""
|
||||
|
||||
words: list[TranscriptionWord] | None = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranscriptionResponseVariant: TypeAlias = (
|
||||
TranscriptionResponse | TranscriptionResponseVerbose
|
||||
)
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
|
||||
|
||||
class TranslationStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
||||
object: Literal["translation.chunk"] = "translation.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranslationResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
class TranslationRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to translate, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: str | None = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
# TODO support additional sampling parameters
|
||||
# --8<-- [start:translation-sampling-params]
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
"""The seed to use for sampling."""
|
||||
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
# --8<-- [end:translation-sampling-params]
|
||||
|
||||
# --8<-- [start:translation-extra-params]
|
||||
language: str | None = None
|
||||
"""The language of the input audio we translate from.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy.
|
||||
"""
|
||||
|
||||
to_language: str | None = None
|
||||
"""The language of the input audio we translate to.
|
||||
|
||||
Please note that this is not supported by all models, refer to the specific
|
||||
model documentation for more details.
|
||||
For instance, Whisper only supports `to_language=en`.
|
||||
"""
|
||||
|
||||
stream: bool | None = False
|
||||
"""Custom field not present in the original OpenAI definition. When set,
|
||||
it will enable output to be streamed in a similar fashion as the Chat
|
||||
Completion endpoint.
|
||||
"""
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: bool | None = False
|
||||
stream_continuous_usage_stats: bool | None = False
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:translation-extra-params]
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
seed=self.seed,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Translation response objects
|
||||
class TranslationResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
|
||||
class TranslationWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
segments: list[TranslationSegment] | None = None
|
||||
"""Segments of the translated text and their corresponding details."""
|
||||
|
||||
words: list[TranslationWord] | None = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
||||
176
vllm/entrypoints/openai/speech_to_text/serving.py
Normal file
176
vllm/entrypoints/openai/speech_to_text/serving.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe",
|
||||
log_error_stack=log_error_stack,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_transcription(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranscriptionRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=(
|
||||
TranscriptionResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranscriptionResponse
|
||||
),
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
log_error_stack=log_error_stack,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_translation(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranslationRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=(
|
||||
TranslationResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranslationResponse
|
||||
),
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self,
|
||||
request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
770
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
Normal file
770
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
Normal file
@@ -0,0 +1,770 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
import zlib
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import cached_property
|
||||
from typing import Final, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionSegment,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import (
|
||||
SupportsTranscription,
|
||||
supports_transcription,
|
||||
)
|
||||
from vllm.multimodal.audio import split_audio
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||
SpeechToTextResponseVerbose: TypeAlias = (
|
||||
TranscriptionResponseVerbose | TranslationResponseVerbose
|
||||
)
|
||||
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
|
||||
S = TypeVar("S", bound=SpeechToTextSegment)
|
||||
|
||||
ResponseType: TypeAlias = (
|
||||
TranscriptionResponse
|
||||
| TranslationResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponseVerbose
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: Literal["transcribe", "translate"] = "transcribe",
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
self.task_type: Final = task_type
|
||||
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
self.model_config, task_type
|
||||
)
|
||||
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
if self.model_cls.supports_segment_timestamp:
|
||||
self.tokenizer = cast(
|
||||
PreTrainedTokenizerBase,
|
||||
get_tokenizer(
|
||||
tokenizer_name=self.model_config.tokenizer,
|
||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||
),
|
||||
)
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
# Warm up audio preprocessing to avoid first-request latency
|
||||
self._warmup_audio_preprocessing()
|
||||
# Warm up input processor with dummy audio
|
||||
self._warmup_input_processor()
|
||||
|
||||
def _warmup_audio_preprocessing(self) -> None:
|
||||
"""Warm up audio processing libraries to avoid first-request latency.
|
||||
|
||||
The first call to librosa functions (load, get_duration, mel-spectrogram)
|
||||
triggers JIT compilation and library initialization which can take ~7s.
|
||||
This method warms up these operations during server initialization.
|
||||
"""
|
||||
# Skip warmup if librosa is not installed (optional dependency)
|
||||
if isinstance(librosa, PlaceholderModule):
|
||||
return
|
||||
|
||||
# Skip warmup if model doesn't support transcription
|
||||
if not supports_transcription(self.model_cls):
|
||||
return
|
||||
|
||||
if getattr(self.model_cls, "skip_warmup_audio_preprocessing", False):
|
||||
return
|
||||
|
||||
try:
|
||||
warmup_start = time.perf_counter()
|
||||
logger.info("Warming up audio preprocessing libraries...")
|
||||
|
||||
# Create a minimal dummy audio (1 second of silence at target sample rate)
|
||||
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
|
||||
|
||||
# Warm up librosa by exercising its functions on the dummy data
|
||||
# This initializes FFTW, numba JIT, and other audio processing libraries
|
||||
_ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)
|
||||
|
||||
# Warm up mel-spectrogram computation with model-specific parameters
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
|
||||
processor = cached_processor_from_config(self.model_config)
|
||||
feature_extractor = None
|
||||
if hasattr(processor, "feature_extractor"):
|
||||
feature_extractor = processor.feature_extractor
|
||||
elif hasattr(processor, "audio_processor"):
|
||||
# For models like GraniteSpeech that use audio_processor
|
||||
audio_proc = processor.audio_processor
|
||||
if hasattr(audio_proc, "feature_extractor"):
|
||||
feature_extractor = audio_proc.feature_extractor
|
||||
# If audio_processor doesn't have feature_extractor,
|
||||
# skip mel-spectrogram warmup for these models
|
||||
|
||||
if feature_extractor is not None:
|
||||
_ = librosa.feature.melspectrogram(
|
||||
y=dummy_audio,
|
||||
sr=self.asr_config.sample_rate,
|
||||
n_mels=getattr(feature_extractor, "n_mels", 128),
|
||||
n_fft=getattr(feature_extractor, "n_fft", 400),
|
||||
hop_length=getattr(feature_extractor, "hop_length", 160),
|
||||
)
|
||||
|
||||
warmup_elapsed = time.perf_counter() - warmup_start
|
||||
logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
|
||||
except Exception:
|
||||
# Don't fail initialization if warmup fails - log exception and continue
|
||||
logger.exception(
|
||||
"Audio preprocessing warmup failed (non-fatal). "
|
||||
"First request may experience higher latency.",
|
||||
)
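# --- Editor's illustrative sketch (not part of this diff): the one-off cost the
# warmup above avoids. The first mel-spectrogram call in a fresh process pays
# library initialization; a second call is fast. Sample rate and clip length are
# assumptions for the example.
import time

import librosa
import numpy as np

silence = np.zeros(16_000, dtype=np.float32)  # 1 s of silence at 16 kHz
for label in ("cold", "warm"):
    t0 = time.perf_counter()
    librosa.feature.melspectrogram(y=silence, sr=16_000)
    print(f"{label} call: {time.perf_counter() - t0:.2f}s")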
|
||||
|
||||
def _warmup_input_processor(self) -> None:
|
||||
"""Warm up input processor with dummy audio to avoid first-request latency.
|
||||
|
||||
The first call to renderer.render_cmpl() with multimodal audio
|
||||
triggers multimodal processing initialization which can take ~2.5s.
|
||||
This method processes a dummy audio request to warm up the pipeline.
|
||||
"""
|
||||
# Skip warmup if model doesn't support transcription
|
||||
if not supports_transcription(self.model_cls):
|
||||
return
|
||||
|
||||
# Only warm up if model supports transcription methods
|
||||
if not hasattr(self.model_cls, "get_generation_prompt"):
|
||||
return
|
||||
|
||||
try:
|
||||
warmup_start = time.perf_counter()
|
||||
logger.info("Warming up multimodal input processor...")
|
||||
|
||||
# Create minimal dummy audio (1 second of silence)
|
||||
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
|
||||
|
||||
# Use the same method that _preprocess_speech_to_text uses
|
||||
# to create the prompt
|
||||
dummy_prompt = self.model_cls.get_generation_prompt(
|
||||
audio=dummy_audio,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language="en",
|
||||
task_type=self.task_type,
|
||||
request_prompt="",
|
||||
to_language=None,
|
||||
)
|
||||
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
|
||||
|
||||
# Process the dummy input through the input processor
|
||||
# This will trigger all the multimodal processing initialization
|
||||
_ = self.renderer.render_cmpl([parsed_prompt])
|
||||
|
||||
warmup_elapsed = time.perf_counter() - warmup_start
|
||||
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
|
||||
except Exception:
|
||||
# Don't fail initialization if warmup fails - log exception and continue
|
||||
logger.exception(
|
||||
"Input processor warmup failed (non-fatal). "
|
||||
"First request may experience higher latency."
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsTranscription]:
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsTranscription], model_cls)
|
||||
|
||||
async def _detect_language(
|
||||
self,
|
||||
audio_chunk: np.ndarray,
|
||||
request_id: str,
|
||||
) -> str:
|
||||
"""Auto-detect the spoken language from an audio chunk.
|
||||
|
||||
Delegates prompt construction and output parsing to the model class
|
||||
via ``get_language_detection_prompt`` and
|
||||
``parse_language_detection_output``.
|
||||
"""
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
prompt = self.model_cls.get_language_detection_prompt(
|
||||
audio_chunk,
|
||||
self.asr_config,
|
||||
)
|
||||
allowed_token_ids = self.model_cls.get_language_token_ids(
|
||||
self.tokenizer,
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=1,
|
||||
temperature=0.0,
|
||||
allowed_token_ids=allowed_token_ids,
|
||||
)
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
)
|
||||
|
||||
final_output: RequestOutput
|
||||
async for final_output in result_generator:
|
||||
if final_output.finished:
|
||||
break
|
||||
|
||||
token_ids = list(final_output.outputs[0].token_ids)
|
||||
lang = self.model_cls.parse_language_detection_output(
|
||||
token_ids,
|
||||
self.tokenizer,
|
||||
)
|
||||
|
||||
logger.info("Auto-detected language: '%s'", lang)
|
||||
return lang
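# --- Editor's standalone sketch (not part of this diff) of the constrained one-token
# decode used above: restrict candidates to the language token ids and take the best
# scoring one. The ids and logits below are made up for illustration.
import numpy as np

LANG_TOKEN_IDS = {50259: "en", 50260: "zh", 50261: "de"}  # hypothetical vocab ids

def pick_language(vocab_logits: np.ndarray) -> str:
    allowed = np.fromiter(LANG_TOKEN_IDS, dtype=np.int64)
    best_id = int(allowed[np.argmax(vocab_logits[allowed])])
    return LANG_TOKEN_IDS[best_id]

print(pick_language(np.random.default_rng(0).normal(size=60_000)))  # one of "en", "zh", "de"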
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
request_id: str,
|
||||
) -> tuple[list[ProcessorInputs], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
# Skip to_language validation to avoid extra logging for Whisper.
|
||||
to_language = (
|
||||
self.model_cls.validate_language(request.to_language)
|
||||
if request.to_language
|
||||
else None
|
||||
)
|
||||
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise VLLMValidationError(
|
||||
"Maximum file size exceeded",
|
||||
parameter="audio_filesize_mb",
|
||||
value=len(audio_data) / 1024**2,
|
||||
)
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE: resample to the model's sample rate here for efficiency. This is also a
|
||||
# prerequisite for chunking, as it assumes the Whisper sample rate.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (
|
||||
self.asr_config.allow_audio_chunking
|
||||
and duration > self.asr_config.max_audio_clip_s
|
||||
)
|
||||
|
||||
if not do_split_audio:
|
||||
chunks = [y]
|
||||
else:
|
||||
assert self.asr_config.max_audio_clip_s is not None
|
||||
assert self.asr_config.min_energy_split_window_size is not None
|
||||
chunks = split_audio(
|
||||
audio_data=y,
|
||||
sample_rate=int(sr),
|
||||
max_clip_duration_s=self.asr_config.max_audio_clip_s,
|
||||
overlap_duration_s=self.asr_config.overlap_chunk_second,
|
||||
min_energy_window_size=self.asr_config.min_energy_split_window_size,
|
||||
)
|
||||
|
||||
if language is None and getattr(
|
||||
self.model_cls, "supports_explicit_language_detection", False
|
||||
):
|
||||
# Auto-detect language from the first chunk.
|
||||
language = await self._detect_language(
|
||||
chunks[0], f"{request_id}-lang_detect"
|
||||
)
|
||||
request.language = language
|
||||
|
||||
parsed_prompts: list[DictPrompt] = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
# returns a valid PromptType.
|
||||
prompt = self.model_cls.get_generation_prompt(
|
||||
audio=chunk,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language=language,
|
||||
task_type=self.task_type,
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
|
||||
parsed_prompt: DictPrompt
|
||||
if request.response_format == "verbose_json":
|
||||
parsed_prompt = parse_enc_dec_prompt(prompt)
|
||||
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
|
||||
else:
|
||||
parsed_prompt = parse_model_prompt(self.model_config, prompt)
|
||||
|
||||
parsed_prompts.append(parsed_prompt)
|
||||
|
||||
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
|
||||
|
||||
return engine_prompts, duration
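# --- Editor's simplified stand-in (not part of this diff) for the chunking above.
# The real split_audio also searches for a low-energy split point inside a window;
# this sketch only shows fixed windows with overlap, which is enough to see how chunk
# boundaries relate to max_audio_clip_s and overlap_chunk_second.
import numpy as np

def naive_split(y: np.ndarray, sr: int, max_clip_s: float, overlap_s: float) -> list[np.ndarray]:
    size = int(max_clip_s * sr)
    if len(y) <= size:
        return [y]
    step = max(size - int(overlap_s * sr), 1)
    return [y[i : i + size] for i in range(0, len(y), step)]

audio = np.zeros(65 * 16_000, dtype=np.float32)  # 65 s of silence at 16 kHz
chunks = naive_split(audio, sr=16_000, max_clip_s=30.0, overlap_s=1.0)
print([len(c) / 16_000 for c in chunks])  # [30.0, 30.0, 7.0]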
|
||||
|
||||
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
|
||||
raise VLLMValidationError(
|
||||
"Expected decoder_prompt to contain text",
|
||||
parameter="decoder_prompt",
|
||||
value=type(dec_prompt).__name__,
|
||||
)
|
||||
|
||||
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
|
||||
"<|notimestamps|>", "<|0.00|>"
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def _get_verbose_segments(
|
||||
self,
|
||||
tokens: tuple,
|
||||
log_probs: FlatLogprobs | list[dict[int, Logprob]],
|
||||
request: SpeechToTextRequest,
|
||||
segment_class: type[SpeechToTextSegment],
|
||||
start_time: float = 0,
|
||||
) -> list[SpeechToTextSegment]:
|
||||
"""
|
||||
Convert tokens to verbose segments.
|
||||
|
||||
This method expects the model to produce
|
||||
timestamps as tokens (similar to Whisper).
|
||||
If the tokens do not include timestamp information,
|
||||
the segments may not be generated correctly.
|
||||
|
||||
Note: the no_speech_prob field is not supported
|
||||
in this implementation and will be None. See docs for details.
|
||||
"""
|
||||
BASE_OFFSET = 0.02
|
||||
init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
|
||||
if tokens[-1] == self.tokenizer.eos_token_id:
|
||||
tokens = tokens[:-1]
|
||||
|
||||
tokens_with_start = (init_token,) + tokens
|
||||
segments: list[SpeechToTextSegment] = []
|
||||
last_timestamp_start = 0
|
||||
|
||||
if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
|
||||
tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
|
||||
avg_logprob = 0.0
|
||||
for idx in range(1, len(tokens_with_start)):
|
||||
# Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
|
||||
# If the ordering is violated, this slicing may produce incorrect results.
|
||||
token = tokens_with_start[idx]
|
||||
if token >= init_token and tokens_with_start[idx - 1] >= init_token:
|
||||
sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
|
||||
start_timestamp = sliced_timestamp_tokens[0] - init_token
|
||||
end_timestamp = sliced_timestamp_tokens[-1] - init_token
|
||||
text = self.tokenizer.decode(sliced_timestamp_tokens[1:-1])
|
||||
text_bytes = text.encode("utf-8")
|
||||
|
||||
casting_segment = cast(
|
||||
SpeechToTextSegment,
|
||||
segment_class(
|
||||
id=len(segments),
|
||||
seek=start_time,
|
||||
start=start_time + BASE_OFFSET * start_timestamp,
|
||||
end=start_time + BASE_OFFSET * end_timestamp,
|
||||
temperature=request.temperature,
|
||||
text=text,
|
||||
# The compression ratio measures
|
||||
# how compressible the generated text is.
|
||||
# A higher ratio indicates more repetitive content,
|
||||
# which is a strong sign of hallucination in outputs.
|
||||
compression_ratio=len(text_bytes)
|
||||
/ len(zlib.compress(text_bytes)),
|
||||
tokens=sliced_timestamp_tokens[1:-1],
|
||||
avg_logprob=avg_logprob / (idx - last_timestamp_start),
|
||||
),
|
||||
)
|
||||
segments.append(casting_segment)
|
||||
last_timestamp_start = idx
|
||||
avg_logprob = 0.0
|
||||
else:
|
||||
avg_logprob += log_probs[idx - 1][token].logprob
|
||||
return segments
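# --- Editor's illustrative sketch (not part of this diff) of the timestamp arithmetic
# above: Whisper-style timestamp tokens are read as offsets from the "<|0.00|>" token,
# at 0.02 s per step. The token id below is a made-up value, not the real vocabulary id.
INIT_TOKEN_ID = 50_000   # hypothetical id of "<|0.00|>"
BASE_OFFSET = 0.02       # seconds per timestamp step

def timestamp_token_to_seconds(token_id: int, chunk_start_s: float = 0.0) -> float:
    return chunk_start_s + BASE_OFFSET * (token_id - INIT_TOKEN_ID)

print(timestamp_token_to_seconds(50_000))        # 0.0
print(timestamp_token_to_seconds(50_150))        # 3.0  (150 steps of 0.02 s)
print(timestamp_token_to_seconds(50_150, 30.0))  # 33.0 for a chunk starting at 30 s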
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[ResponseType],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ["text", "json", "verbose_json"]:
|
||||
return self.create_error_response(
|
||||
"Currently supported response_format values: "
|
||||
"`text`, `json` or `verbose_json`"
|
||||
)
|
||||
|
||||
if (
|
||||
request.response_format == "verbose_json"
|
||||
and not self.model_cls.supports_segment_timestamp
|
||||
):
|
||||
return self.create_error_response(
|
||||
f"verbose_json is currently not supported for {request.model}"
|
||||
)
|
||||
|
||||
if request.response_format == "verbose_json" and request.stream:
|
||||
return self.create_error_response(
|
||||
"The verbose_json response format does not support streaming"
|
||||
)
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
engine_prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(e)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel spectrogram. Still, allow for fewer tokens to be
|
||||
# generated by respecting the extra completion tokens arg.
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
request.max_completion_tokens,
|
||||
0,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
if request.response_format == "verbose_json":
|
||||
sampling_params.logprobs = 1
|
||||
|
||||
list_result_generator = []
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}_{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
list_result_generator.append(generator)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(
|
||||
request, list_result_generator, request_id, request_metadata, duration_s
|
||||
)
|
||||
# Non-streaming response.
|
||||
total_segments = []
|
||||
text_parts = []
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
segments_types: dict[str, type[SpeechToTextSegment]] = {
|
||||
"transcribe": TranscriptionSegment,
|
||||
"translate": TranslationSegment,
|
||||
}
|
||||
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
||||
text = ""
|
||||
chunk_size_in_s = self.asr_config.max_audio_clip_s
|
||||
if chunk_size_in_s is None:
|
||||
assert len(list_result_generator) == 1, (
|
||||
"`max_audio_clip_s` is None, so the audio should not have been chunked"
|
||||
)
|
||||
for idx, result_generator in enumerate(list_result_generator):
|
||||
start_time = (
|
||||
float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
|
||||
)
|
||||
async for op in result_generator:
|
||||
if request.response_format == "verbose_json":
|
||||
assert op.outputs[0].logprobs
|
||||
segments: list[SpeechToTextSegment] = (
|
||||
self._get_verbose_segments(
|
||||
tokens=tuple(op.outputs[0].token_ids),
|
||||
segment_class=segment_class,
|
||||
request=request,
|
||||
start_time=start_time,
|
||||
log_probs=op.outputs[0].logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
total_segments.extend(segments)
|
||||
text_parts.extend([seg.text for seg in segments])
|
||||
else:
|
||||
raw_text = op.outputs[0].text
|
||||
text_parts.append(self.model_cls.post_process_output(raw_text))
|
||||
text = "".join(text_parts)
|
||||
if self.task_type == "transcribe":
|
||||
final_response: ResponseType
|
||||
# add usage in TranscriptionResponse.
|
||||
usage = {
|
||||
"type": "duration",
|
||||
# rounded up as per the OpenAI spec
|
||||
"seconds": int(math.ceil(duration_s)),
|
||||
}
|
||||
if request.response_format != "verbose_json":
|
||||
final_response = cast(
|
||||
T, TranscriptionResponse(text=text, usage=usage)
|
||||
)
|
||||
else:
|
||||
final_response = cast(
|
||||
V,
|
||||
TranscriptionResponseVerbose(
|
||||
text=text,
|
||||
language=request.language,
|
||||
duration=str(duration_s),
|
||||
segments=total_segments,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# no usage in response for translation task
|
||||
if request.response_format != "verbose_json":
|
||||
final_response = cast(T, TranslationResponse(text=text))
|
||||
else:
|
||||
final_response = cast(
|
||||
V,
|
||||
TranslationResponseVerbose(
|
||||
text=text,
|
||||
language=request.language,
|
||||
duration=str(duration_s),
|
||||
segments=total_segments,
|
||||
),
|
||||
)
|
||||
return final_response
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
|
||||
response_stream_choice_class: type[TranscriptionResponseStreamChoice]
|
||||
| type[TranslationResponseStreamChoice],
|
||||
stream_response_class: type[TranscriptionStreamResponse]
|
||||
| type[TranslationStreamResponse],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = self.enable_force_include_usage or request.stream_include_usage
|
||||
include_continuous_usage = (
|
||||
request.stream_continuous_usage_stats
|
||||
if include_usage and request.stream_continuous_usage_stats
|
||||
else False
|
||||
)
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens = len(res.prompt_token_ids)
|
||||
if audio_tokens := self.model_cls.get_num_audio_tokens(
|
||||
audio_duration_s, self.asr_config, self.model_config
|
||||
):
|
||||
num_prompt_tokens += audio_tokens
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, the error needs to be sent as the FIRST
|
||||
# response (by the try...except).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
# TODO: For models that output structured formats (e.g.,
|
||||
# Qwen3-ASR with "language X<asr_text>" prefix), streaming
|
||||
# would need buffering to strip the prefix properly since
|
||||
# deltas may split the tag across chunks.
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = response_stream_choice_class(delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
|
||||
chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True
|
||||
)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# Report the aggregate usage across all choices to the FastAPI middleware
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
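# --- Editor's client-side sketch (not part of this diff): consuming the "data: ..." /
# "data: [DONE]" SSE stream emitted above. The base URL, model name, and the exact
# multipart form fields are assumptions for illustration only.
import json

import requests

def stream_transcription(audio_path: str, base_url: str = "http://localhost:8000") -> str:
    with open(audio_path, "rb") as f:
        resp = requests.post(
            f"{base_url}/v1/audio/transcriptions",
            files={"file": f},
            data={"model": "openai/whisper-large-v3", "stream": "true"},
            stream=True,
        )
    resp.raise_for_status()
    parts: list[str] = []
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        for choice in chunk.get("choices", []):
            parts.append(choice.get("delta", {}).get("content") or "")
    return "".join(parts)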
|
||||
12
vllm/entrypoints/openai/translations/__init__.py
Normal file
12
vllm/entrypoints/openai/translations/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
|
||||
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
|
||||
"This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
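# Migration sketch (editor's illustration): imports through this alias keep working but
# emit the warning above; switching to the new package name silences it, e.g.
#   from vllm.entrypoints.openai.translations import protocol       # old, warns
#   from vllm.entrypoints.openai.speech_to_text import protocol     # new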
|
||||
14
vllm/entrypoints/openai/translations/api_router.py
Normal file
14
vllm/entrypoints/openai/translations/api_router.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.api_router' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402
|
||||
14
vllm/entrypoints/openai/translations/protocol.py
Normal file
14
vllm/entrypoints/openai/translations/protocol.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.protocol' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402
|
||||
14
vllm/entrypoints/openai/translations/serving.py
Normal file
14
vllm/entrypoints/openai/translations/serving.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.serving' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.serving'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402
|
||||
15
vllm/entrypoints/openai/translations/speech_to_text.py
Normal file
15
vllm/entrypoints/openai/translations/speech_to_text.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.speech_to_text' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update "
|
||||
"your imports. This backward-compatible alias will be removed in version "
|
||||
"0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402
|
||||
49
vllm/entrypoints/openai/utils.py
Normal file
49
vllm/entrypoints/openai/utils.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
# Used internally
|
||||
_ChatCompletionResponseChoiceT = TypeVar(
|
||||
"_ChatCompletionResponseChoiceT",
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
|
||||
def maybe_filter_parallel_tool_calls(
|
||||
choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
|
||||
) -> _ChatCompletionResponseChoiceT:
|
||||
"""Filter to first tool call only when parallel_tool_calls is False."""
|
||||
|
||||
if request.parallel_tool_calls:
|
||||
return choice
|
||||
|
||||
if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
|
||||
choice.message.tool_calls = choice.message.tool_calls[:1]
|
||||
elif (
|
||||
isinstance(choice, ChatCompletionResponseStreamChoice)
|
||||
and choice.delta.tool_calls
|
||||
):
|
||||
choice.delta.tool_calls = [
|
||||
tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
|
||||
]
|
||||
|
||||
return choice
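# --- Editor's standalone analog (not part of this diff) of the filtering rule above,
# with the vLLM protocol types replaced by a plain dataclass so the sketch is
# self-contained: when parallel tool calls are disabled, only the first call survives.
from dataclasses import dataclass, field

@dataclass
class _FakeMessage:
    tool_calls: list[str] = field(default_factory=list)

def _filter_first_call(message: _FakeMessage, parallel_tool_calls: bool) -> _FakeMessage:
    if not parallel_tool_calls and message.tool_calls:
        message.tool_calls = message.tool_calls[:1]
    return message

msg = _FakeMessage(tool_calls=["get_weather", "get_time"])
print(_filter_first_call(msg, parallel_tool_calls=False).tool_calls)  # ['get_weather']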
|
||||
|
||||
|
||||
async def validate_json_request(raw_request: Request):
|
||||
content_type = raw_request.headers.get("content-type", "").lower()
|
||||
media_type = content_type.split(";", maxsplit=1)[0]
|
||||
if media_type != "application/json":
|
||||
raise RequestValidationError(
|
||||
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
|
||||
)
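# --- Editor's wiring sketch (not part of this diff; route path and handler are
# assumptions): the check above is typically attached as a FastAPI dependency so
# non-JSON requests fail with a validation error before the handler runs.
from fastapi import Depends, FastAPI

app = FastAPI()

@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
async def create_chat_completion(payload: dict) -> dict:
    return payload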
|
||||