first commit

2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

10 binary files not shown.

vllm/entrypoints/api_server.py

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
return await _generate(request_dict, raw_request=request)
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\n").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
def build_app(args: Namespace) -> FastAPI:
global app
app.root_path = args.root_path
return app
async def init_app(
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
app = build_app(args)
global engine
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER))
app.state.engine_client = engine
return app
async def run_server(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs: Any) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
set_ulimit()
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,
sock=None,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
await shutdown_task
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=parser.check_port, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
asyncio.run(run_server(args))
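
Since this demo server only documents its request shape in the docstring above, here is a minimal client sketch against the /generate endpoint. It assumes the server is running locally on port 8000 and that the requests package is installed; the prompt and sampling fields are illustrative (any extra fields are forwarded to SamplingParams).

import requests

payload = {
    "prompt": "San Francisco is a",
    "stream": False,
    # Remaining fields are passed through to SamplingParams.
    "max_tokens": 32,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["text"])  # list of prompt + completion strings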

File diff suppressed because it is too large.

vllm/entrypoints/cli/benchmark/__init__.py

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.throughput import (
BenchmarkThroughputSubcommand)
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkServingSubcommand",
"BenchmarkThroughputSubcommand",
]

3 binary files not shown.

vllm/entrypoints/cli/benchmark/base.py

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkSubcommandBase(CLISubcommand):
""" The base class of subcommands for vllm bench. """
help: str
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
"""Add the CLI arguments to the parser."""
raise NotImplementedError
@staticmethod
def cmd(args: argparse.Namespace) -> None:
"""Run the benchmark.
Args:
args: The arguments to the command.
"""
raise NotImplementedError

vllm/entrypoints/cli/benchmark/latency.py

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
""" The `latency` subcommand for vllm bench. """
name = "latency"
help = "Benchmark the latency of a single batch of requests."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

vllm/entrypoints/cli/benchmark/main.py

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class BenchmarkSubcommand(CLISubcommand):
""" The `bench` subcommand for the vLLM CLI. """
name = "bench"
help = "vLLM bench subcommand."
@staticmethod
def cmd(args: argparse.Namespace) -> None:
args.dispatch_function(args)
def validate(self, args: argparse.Namespace) -> None:
pass
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
bench_parser = subparsers.add_parser(
self.name,
description=self.help,
usage=f"vllm {self.name} <bench_type> [options]")
bench_subparsers = bench_parser.add_subparsers(required=True,
dest="bench_type")
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
cmd_subparser = bench_subparsers.add_parser(
cmd_cls.name,
help=cmd_cls.help,
description=cmd_cls.help,
usage=f"vllm {self.name} {cmd_cls.name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd)
cmd_cls.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"{self.name} {cmd_cls.name}")
return bench_parser
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkSubcommand()]
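
Because subparser_init above discovers benchmarks via BenchmarkSubcommandBase.__subclasses__(), adding a new `vllm bench` command is just a matter of defining (and importing) another subclass. A hypothetical sketch, with an invented `sweep` name and --num-prompts flag:

import argparse

from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase


class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
    """Illustrative `sweep` subcommand for vllm bench (not part of vLLM)."""
    name = "sweep"
    help = "Illustrative benchmark subcommand."

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
        parser.add_argument("--num-prompts", type=int, default=100)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        print(f"would benchmark {args.num_prompts} prompts")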

vllm/entrypoints/cli/benchmark/serve.py

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.serve import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
""" The `serve` subcommand for vllm bench. """
name = "serve"
help = "Benchmark the online serving throughput."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

vllm/entrypoints/cli/benchmark/throughput.py

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
""" The `throughput` subcommand for vllm bench. """
name = "throughput"
help = "Benchmark offline inference throughput."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

vllm/entrypoints/cli/collect_env.py

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import typing
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand):
"""The `collect-env` subcommand for the vLLM CLI. """
name = "collect-env"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
"""Collect information about the environment."""
collect_env_main()
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
return subparsers.add_parser(
"collect-env",
help="Start collecting environment information.",
description="Start collecting environment information.",
usage="vllm collect-env")
def cmd_init() -> list[CLISubcommand]:
return [CollectEnvSubcommand()]

vllm/entrypoints/cli/main.py

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''The CLI entrypoints of vLLM
Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage.'''
from __future__ import annotations
import importlib.metadata
def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM CLI",
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
)
parser.add_argument(
'-v',
'--version',
action='version',
version=importlib.metadata.version('vllm'),
)
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd_module in CMD_MODULES:
new_cmds = cmd_module.cmd_init()
for cmd in new_cmds:
cmd.subparser_init(subparsers).set_defaults(
dispatch_function=cmd.cmd)
cmds[cmd.name] = cmd
args = parser.parse_args()
if args.subparser in cmds:
cmds[args.subparser].validate(args)
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
else:
parser.print_help()
if __name__ == "__main__":
main()

vllm/entrypoints/cli/openai.py

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import os
import signal
import sys
from typing import TYPE_CHECKING
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.cli.types import CLISubcommand
if TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
def _register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTSTP, signal_handler)
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
_register_signal_handlers()
base_url = args.url
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
openai_client = OpenAI(api_key=api_key, base_url=base_url)
if args.model_name:
model_name = args.model_name
else:
available_models = openai_client.models.list()
model_name = available_models.data[0].id
print(f"Using model: {model_name}")
return model_name, openai_client
def _print_chat_stream(stream) -> str:
output = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
output += delta.content
print(delta.content, end="", flush=True)
print()
return output
def _print_completion_stream(stream) -> str:
output = ""
for chunk in stream:
text = chunk.choices[0].text
if text is not None:
output += text
print(text, end="", flush=True)
print()
return output
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:")
while True:
try:
input_message = input("> ")
except EOFError:
break
conversation.append({"role": "user", "content": input_message})
stream = client.chat.completions.create(model=model_name,
messages=conversation,
stream=True)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
def _add_query_options(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--url",
type=str,
default="http://localhost:8000/v1",
help="url of the running OpenAI-Compatible RESTful API server")
parser.add_argument(
"--model-name",
type=str,
default=None,
help=("The model name used in prompt completion, default to "
"the first model in list models API call."))
parser.add_argument(
"--api-key",
type=str,
default=None,
help=(
"API key for OpenAI services. If provided, this api key "
"will overwrite the api key obtained through environment variables."
))
return parser
class ChatCommand(CLISubcommand):
"""The `chat` subcommand for the vLLM CLI. """
name = "chat"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
model_name, client = _interactive_cli(args)
system_prompt = args.system_prompt
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
if args.quick:
conversation.append({"role": "user", "content": args.quick})
stream = client.chat.completions.create(model=model_name,
messages=conversation,
stream=True)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
return
print("Please enter a message for the chat model:")
while True:
try:
input_message = input("> ")
except EOFError:
break
conversation.append({"role": "user", "content": input_message})
stream = client.chat.completions.create(model=model_name,
messages=conversation,
stream=True)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Add CLI arguments for the chat command."""
_add_query_options(parser)
parser.add_argument(
"--system-prompt",
type=str,
default=None,
help=("The system prompt to be added to the chat template, "
"used for models that support system prompts."))
parser.add_argument("-q",
"--quick",
type=str,
metavar="MESSAGE",
help=("Send a single prompt as MESSAGE "
"and print the response, then exit."))
return parser
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
parser = subparsers.add_parser(
"chat",
help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]")
return ChatCommand.add_cli_args(parser)
class CompleteCommand(CLISubcommand):
"""The `complete` subcommand for the vLLM CLI. """
name = 'complete'
@staticmethod
def cmd(args: argparse.Namespace) -> None:
model_name, client = _interactive_cli(args)
if args.quick:
stream = client.completions.create(model=model_name,
prompt=args.quick,
stream=True)
_print_completion_stream(stream)
return
print("Please enter prompt to complete:")
while True:
try:
input_prompt = input("> ")
except EOFError:
break
stream = client.completions.create(model=model_name,
prompt=input_prompt,
stream=True)
_print_completion_stream(stream)
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Add CLI arguments for the complete command."""
_add_query_options(parser)
parser.add_argument(
"-q",
"--quick",
type=str,
metavar="PROMPT",
help=
"Send a single prompt and print the completion output, then exit.")
return parser
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
parser = subparsers.add_parser(
"complete",
help=("Generate text completions based on the given prompt "
"via the running API server."),
description=("Generate text completions based on the given prompt "
"via the running API server."),
usage="vllm complete [options]")
return CompleteCommand.add_cli_args(parser)
def cmd_init() -> list[CLISubcommand]:
return [ChatCommand(), CompleteCommand()]
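
For reference, the chat command above boils down to plain OpenAI-client usage against the running server. A minimal sketch, assuming a server is listening on localhost:8000/v1 (the default --url) and accepts the placeholder API key:

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model_name = client.models.list().data[0].id

stream = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
print()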

vllm/entrypoints/cli/run_batch.py

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import asyncio
import importlib.metadata
import typing
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.logger import init_logger
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
class RunBatchSubcommand(CLISubcommand):
"""The `run-batch` subcommand for vLLM CLI."""
name = "run-batch"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
from vllm.entrypoints.openai.run_batch import main as run_batch_main
logger.info("vLLM batch processing API version %s",
importlib.metadata.version("vllm"))
logger.info("args: %s", args)
# Start the Prometheus metrics server.
# LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
from prometheus_client import start_http_server
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(run_batch_main(args))
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
from vllm.entrypoints.openai.run_batch import make_arg_parser
run_batch_parser = subparsers.add_parser(
self.name,
help="Run batch prompts and write results to file.",
description=(
"Run batch prompts using vLLM's OpenAI-compatible API.\n"
"Supports local or HTTP input/output files."),
usage=
"vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
)
run_batch_parser = make_arg_parser(run_batch_parser)
run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=self.name)
return run_batch_parser
def cmd_init() -> list[CLISubcommand]:
return [RunBatchSubcommand()]

vllm/entrypoints/cli/serve.py

@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import signal
from typing import Optional
import uvloop
import vllm
import vllm.envs as envs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, decorate_logs, get_tcp_uri,
set_process_title)
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager,
wait_for_completion_or_failure)
logger = init_logger(__name__)
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.
Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
Use `--help=all` to show all available flags at once.
"""
class ServeSubcommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI. """
name = "serve"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
# If model is specified in CLI (as positional arg), it takes precedence
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
self.name,
description=DESCRIPTION,
usage="vllm serve [model_tag] [options]")
serve_parser = make_arg_parser(serve_parser)
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=self.name)
return serve_parser
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
def run_headless(args: argparse.Namespace):
if args.api_server_count > 1:
raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in "
"headless mode")
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode")
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
logger.debug("Received %d signal.", signum)
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, handshake_address)
# Create the engines.
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=vllm_config.parallel_config.data_parallel_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
try:
engine_manager.join_first()
finally:
logger.info("Shutting down.")
engine_manager.close()
def run_multi_api_server(args: argparse.Namespace):
assert not args.headless
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1:
setup_multiprocess_prometheus()
listen_address, sock = setup_server(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
engine_args._api_process_count = num_api_servers
engine_args._api_process_rank = -1
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if num_api_servers > 1:
if not envs.VLLM_USE_V1:
raise ValueError("api_server_count > 1 is only supported for V1")
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
parallel_config = vllm_config.parallel_config
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
api_server_manager: Optional[APIServerProcessManager] = None
with launch_core_engines(vllm_config, executor_class, log_stats,
num_api_servers) as (local_engine_manager,
coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=addresses.inputs,
output_addresses=addresses.outputs,
stats_update_address=coordinator.get_stats_publish_address()
if coordinator else None)
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
# start of the API servers until the local engine is started
# (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
# via the handshake with the local engine.
if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb):
# Start API servers using the manager.
api_server_manager = APIServerProcessManager(
**api_server_manager_kwargs)
# Start API servers now if they weren't already started.
if api_server_manager is None:
api_server_manager_kwargs["stats_update_address"] = (
addresses.frontend_stats_publish_address)
api_server_manager = APIServerProcessManager(
**api_server_manager_kwargs)
# Wait for API servers
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Entrypoint for individual API server worker processes."""
client_config = client_config or {}
server_index = client_config.get("client_index", 0)
# Set process title and add process-specific prefix to stdout and stderr.
set_process_title("APIServer", str(server_index))
decorate_logs()
uvloop.run(
run_server_worker(listen_address, sock, args, client_config,
**uvicorn_kwargs))

vllm/entrypoints/cli/types.py

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import typing
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CLISubcommand:
"""Base class for CLI argument handlers."""
name: str
@staticmethod
def cmd(args: argparse.Namespace) -> None:
raise NotImplementedError("Subclasses should implement this method")
def validate(self, args: argparse.Namespace) -> None:
# No validation by default
pass
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
raise NotImplementedError("Subclasses should implement this method")
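
For context, main() in cli/main.py (above) calls each command module's cmd_init(), wires subparser_init() into the top-level parser, and dispatches to cmd(). A hypothetical module implementing this contract (the `hello` command and --who flag are invented for illustration):

from __future__ import annotations

import argparse

from vllm.entrypoints.cli.types import CLISubcommand


class HelloSubcommand(CLISubcommand):
    """Illustrative subcommand; not part of vLLM."""
    name = "hello"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        print(f"hello, {args.who}")

    def subparser_init(self, subparsers: argparse._SubParsersAction):
        parser = subparsers.add_parser("hello", help="Say hello.")
        parser.add_argument("--who", default="world")
        return parser


def cmd_init() -> list[CLISubcommand]:
    return [HelloSubcommand()]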

vllm/entrypoints/constants.py

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared constants for vLLM entrypoints.
"""
# HTTP header limits for h11 parser
# These constants help mitigate header abuse attacks
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
H11_MAX_HEADER_COUNT_DEFAULT = 256

vllm/entrypoints/context.py

@@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import logging
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
if TYPE_CHECKING:
from mcp.client import ClientSession
logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}")
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
class TurnTokens:
"""Tracks token counts for a single conversation turn."""
def __init__(self, input_tokens=0, output_tokens=0):
self.input_tokens = input_tokens
self.output_tokens = output_tokens
def reset(self):
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
def copy(self):
"""Create a copy of this turn's token counts."""
return TurnTokens(self.input_tokens, self.output_tokens)
class ConversationContext(ABC):
@abstractmethod
def append_output(self, output) -> None:
pass
@abstractmethod
async def call_tool(self) -> list[Message]:
pass
@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass
@abstractmethod
def render_for_completion(self) -> list[int]:
pass
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
@abstractmethod
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class SimpleContext(ConversationContext):
def __init__(self):
self.last_output = None
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# todo num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
def append_output(self, output) -> None:
self.last_output = output
if not isinstance(output, RequestOutput):
raise ValueError("SimpleContext only supports RequestOutput.")
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
def need_builtin_tool_call(self) -> bool:
return False
async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
available_tools: list[str],
):
self._messages = messages
self.finish_reason: Optional[str] = None
self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.called_tools: set[str] = set()
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
# Turn tracking - replaces multiple individual tracking variables
self.current_turn = TurnTokens()
self.previous_turn = TurnTokens()
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
def _update_num_reasoning_tokens(self):
# Count all analysis and commentary channels as reasoning tokens
if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1
def append_output(self, output: Union[RequestOutput,
list[Message]]) -> None:
if isinstance(output, RequestOutput):
output_token_ids = output.outputs[0].token_ids
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
# Reset current turn output tokens for this turn
self.current_turn.output_tokens = 0
self._update_decode_token_usage(output)
# Move current turn to previous turn for next turn's calculations
self.previous_turn = self.current_turn.copy()
# append_output is called only once before tool calling
# in non-streaming case
# so we can append all the parser messages to _messages
output_msgs = self.parser.messages
# The response's finish reason is set in the last message
self.finish_reason = output.outputs[0].finish_reason
else:
# Tool output.
output_msgs = output
self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if output.prompt_token_ids is not None:
this_turn_input_tokens = len(output.prompt_token_ids)
else:
this_turn_input_tokens = 0
logger.error(
"RequestOutput appended contains no prompt_token_ids.")
# Update current turn input tokens
self.current_turn.input_tokens = this_turn_input_tokens
self.num_prompt_tokens += this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if self.is_first_turn:
self.is_first_turn = False
else:
# start counting tool after first turn
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens = (self.current_turn.input_tokens -
self.previous_turn.input_tokens -
self.previous_turn.output_tokens)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if this_turn_tool_tokens < 0:
logger.error(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0.",
this_turn_tool_tokens, self.current_turn.input_tokens,
self.previous_turn.input_tokens,
self.previous_turn.output_tokens)
this_turn_tool_tokens = 0
self.num_tool_output_tokens += this_turn_tool_tokens
# Update cached tokens
if output.num_cached_tokens is not None:
self.num_cached_tokens += output.num_cached_tokens
def _update_decode_token_usage(self, output: RequestOutput) -> int:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count = 0
if output.outputs:
for completion_output in output.outputs:
# only keep last round
updated_output_token_count += len(completion_output.token_ids)
self.num_output_tokens += updated_output_token_count
self.current_turn.output_tokens += updated_output_token_count
return updated_output_token_count
@property
def messages(self) -> list:
return self._messages
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (recipient.startswith("browser.")
or recipient.startswith("python") or
recipient.startswith("container."))
async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self._tool_sessions["browser"], last_msg)
elif recipient.startswith("python"):
return await self.call_python_tool(
self._tool_sessions["python"], last_msg)
elif recipient.startswith("container."):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
async def call_search_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
]
async def call_python_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")
return [
Message(author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT)
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = mcp_tools[
tool_type].headers if tool_type in mcp_tools else None
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id,
headers))
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def call_container_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
"""
Call the container tool. Expects to be run in a stateful Docker container
with a command-line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
]
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info("Cleaning up tool session for %s",
tool_session._client_info)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools))
class StreamingHarmonyContext(HarmonyContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
self.first_tok_of_message = True
@property
def messages(self) -> list:
return self._messages
def append_output(self, output: Union[RequestOutput,
list[Message]]) -> None:
if isinstance(output, RequestOutput):
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message:
self._update_prefill_token_usage(output)
self.current_turn.output_tokens = 0
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
# beginning of a new message
self.first_tok_of_message = output.finished
for tok in output.outputs[0].token_ids:
self.parser.process(tok)
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.previous_turn = self.current_turn.copy()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self.last_tok = tok
if len(self._messages) - self.num_init_messages < len(
self.parser.messages):
self._messages.extend(
self.parser.messages[len(self._messages) -
self.num_init_messages:])
else:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
# TODO: add tool_output messages to self._messages
def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START
def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions(
)
def render_for_completion(self) -> list[int]:
# The rendered tokens end with the next turn's starting tokens
# (`<|start|>assistant`), which we still need to feed through the parser.
rendered_tokens = super().render_for_completion()
last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)
return rendered_tokens
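
A worked example of the tool-token accounting documented in _update_prefill_token_usage above (the numbers are illustrative):

# turn 1: prompt = 900 tokens, model generates 50 tokens
# turn 2: the tool response is appended, so the new prompt is 1000 tokens
previous_input, previous_output = 900, 50
current_input = 1000
tool_output_tokens = current_input - previous_input - previous_output
assert tool_output_tokens == 50  # tokens contributed by the tool between turns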

vllm/entrypoints/harmony_utils.py

@@ -0,0 +1,436 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import datetime
import json
from collections.abc import Iterable, Sequence
from typing import Literal, Optional, Union
from openai.types.responses import (ResponseFunctionToolCall,
ResponseOutputItem, ResponseOutputMessage,
ResponseOutputText, ResponseReasoningItem)
from openai.types.responses.response_function_web_search import (
ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent)
from openai.types.responses.tool import Tool
from openai_harmony import (Author, ChannelConfig, Conversation,
DeveloperContent, HarmonyEncodingName, Message,
ReasoningEffort, Role, StreamableParser,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)
from vllm import envs
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
ResponseInputOutputItem)
from vllm.utils import random_uuid
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOLS = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: list[str]) -> bool:
return not set(tool_types).issubset(BUILTIN_TOOLS)
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(
HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
container_description: Optional[str] = None,
instructions: Optional[str] = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if (instructions is not None
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
current_identity = sys_msg_content.model_identity
new_identity = (f'{current_identity}\n{instructions}'
if current_identity else instructions)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort])
if start_date is None:
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
if not with_custom_tools:
channel_config = sys_msg_content.channel_config
invalid_channel = "commentary"
new_config = ChannelConfig.require_channels(
[c for c in channel_config.valid_channels if c != invalid_channel])
sys_msg_content = sys_msg_content.with_channel_config(new_config)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: Optional[str] = None,
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if (instructions is not None
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter",
"container", "mcp"):
# These are built-in tools that are added to the system message.
# Adding in MCP for now until we support MCP tools executed
# server side
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
create_tool_definition(tool) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_response_input(
response_msg: ResponseInputOutputItem,
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
) -> Message:
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [
TextContent(text=text_prefix + c["text"]) for c in content
]
msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
for prev_response in reversed(prev_responses):
if isinstance(prev_response, ResponseFunctionToolCall
) and prev_response.call_id == call_id:
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"])
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT,
response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def parse_chat_input(chat_msg) -> list[Message]:
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"),
content).with_channel("commentary")
return [msg]
# Default: user/assistant/system messages with content
content = chat_msg.get("content", "")
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents)
return [msg]
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT)
return token_ids
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
"""
Parse a Harmony message into a list of output response items.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search")
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page")
elif recipient == "browser.find":
action = ActionFind(pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find")
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text,
type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if recipient is not None and recipient.startswith("functions."):
function_name = recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
elif recipient is not None and (recipient.startswith("python")
or recipient.startswith("browser")
or recipient.startswith("container")):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text,
type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def parse_remaining_state(
parser: StreamableParser) -> list[ResponseOutputItem]:
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if (current_recipient is not None
and current_recipient.startswith("browser.")):
return []
if parser.current_channel == "analysis":
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=parser.current_content,
type="reasoning_text")
],
status=None,
)
return [reasoning_item]
elif parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
# if the parser still has messages (ie if the generator got cut
# abruptly), this should be incomplete
status="incomplete",
type="message",
)
return [text_item]
return []
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
def parse_chat_output(
token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]:
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
reasoning_content = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
reasoning_content = output_msgs[0].content[0].text
final_content = parser.current_content
else:
reasoning_msg = output_msgs[:-1]
final_msg = output_msgs[-1]
reasoning_content = "\n".join(
[msg.content[0].text for msg in reasoning_msg])
final_content = final_msg.content[0].text
return reasoning_content, final_content, is_tool_call
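# A minimal usage sketch of the helpers above (the `token_ids` sequence is
# illustrative and stands in for whatever the engine produced):
#
#     reasoning, final, _is_tool_call = parse_chat_output(token_ids)
#     if final is None:
#         # generation stopped while the model was still reasoning
#         ...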

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import signal
import socket
from http import HTTPStatus
from typing import Any, Optional
import uvicorn
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
async def serve_http(app: FastAPI,
sock: Optional[socket.socket],
enable_ssl_refresh: bool = False,
**uvicorn_kwargs: Any):
"""
Start a FastAPI app using Uvicorn, with support for custom Uvicorn config
    options. Supports HTTP header limits via h11_max_incomplete_event_size and
    h11_max_header_count.
"""
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
path = getattr(route, "path", None)
if methods is None or path is None:
continue
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
# Extract header limit options if present
h11_max_incomplete_event_size = uvicorn_kwargs.pop(
"h11_max_incomplete_event_size", None)
h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None)
# Set safe defaults if not provided
if h11_max_incomplete_event_size is None:
h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
if h11_max_header_count is None:
h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT
config = uvicorn.Config(app, **uvicorn_kwargs)
# Set header limits
config.h11_max_incomplete_event_size = h11_max_incomplete_event_size
config.h11_max_header_count = h11_max_header_count
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
loop = asyncio.get_running_loop()
watchdog_task = loop.create_task(
watchdog_loop(server, app.state.engine_client))
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
ssl_context=config.ssl,
key_path=config.ssl_keyfile,
cert_path=config.ssl_certfile,
ca_path=config.ssl_ca_certs)
def signal_handler() -> None:
        # prevents the uvicorn signal handler from exiting early
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
async def dummy_shutdown() -> None:
pass
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
return dummy_shutdown()
except asyncio.CancelledError:
port = uvicorn_kwargs["port"]
process = find_process_using_port(port)
if process is not None:
logger.warning(
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
watchdog_task.cancel()
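# A minimal invocation sketch (keyword values are illustrative; real callers
# also set `app.state.engine_client`, which the watchdog below relies on):
#
#     shutdown_coro = await serve_http(app,
#                                      sock=None,
#                                      host="0.0.0.0",
#                                      port=8000,
#                                      h11_max_header_count=256)
#     await shutdown_coro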
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
"""
# Watchdog task that runs in the background, checking
# for error state in the engine. Needed to trigger shutdown
# if an exception arises is StreamingResponse() generator.
"""
VLLM_WATCHDOG_TIME_S = 5.0
while True:
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
terminate_if_errored(server, engine)
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
"""
See discussions here on shutting down a uvicorn server
https://github.com/encode/uvicorn/discussions/1103
    In this case we cannot await the server shutdown here
    because the handler must first return to close the connection
    for this request.
"""
engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
    background task to check for the errored state.
"""
@app.exception_handler(RuntimeError)
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

1629
vllm/entrypoints/llm.py Normal file

File diff suppressed because it is too large.

View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
lora_request)
def log_outputs(
self,
request_id: str,
outputs: str,
output_token_ids: Optional[Sequence[int]],
finish_reason: Optional[str] = None,
is_streaming: bool = False,
delta: bool = False,
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if outputs is not None:
outputs = outputs[:max_log_len]
if output_token_ids is not None:
# Convert to list and apply truncation
output_token_ids = list(output_token_ids)[:max_log_len]
stream_info = ""
if is_streaming:
stream_info = (" (streaming delta)"
if delta else " (streaming complete)")
logger.info(
"Generated response %s%s: output: %r, "
"output_token_ids: %s, finish_reason: %s",
request_id,
stream_info,
outputs,
output_token_ids,
finish_reason,
)
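# A minimal usage sketch (identifiers and values are illustrative):
#
#     request_logger = RequestLogger(max_log_len=128)
#     request_logger.log_inputs("req-0", prompt="Hello",
#                               prompt_token_ids=None, prompt_embeds=None,
#                               params=None, lora_request=None)
#     request_logger.log_outputs("req-0", "Hi there!", None, "stop")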

View File

File diff suppressed because it is too large.

View File

@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains the command-line arguments for vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from collections.abc import Sequence
from dataclasses import field
from typing import Literal, Optional, Union
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: list[LoRAModulePath] = []
for item in values:
if item in [None, ""]: # Skip if item is None or empty string
continue
if "=" in item and "," not in item: # Old format: name=path
name, path = item.split("=")
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
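# Both accepted forms of --lora-modules in practice (values are illustrative):
#
#     --lora-modules my-lora=/path/to/lora_dir
#     --lora-modules '{"name": "my-lora", "path": "/path/to/lora_dir",
#                      "base_model_name": "base-model-id"}'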
@config
@dataclass
class FrontendArgs:
"""Arguments for the OpenAI-compatible frontend server."""
host: Optional[str] = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: Optional[str] = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
"trace"] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: Optional[list[str]] = None
"""If provided, the server will require one of these keys to be presented in
the header."""
lora_modules: Optional[list[LoRAModulePath]] = None
"""LoRA modules configurations in either 'name=path' format or JSON format
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`"""
chat_template: Optional[str] = None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
    or the one from the tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None
"""The file path to the SSL key file."""
ssl_certfile: Optional[str] = None
"""The file path to the SSL cert file."""
ssl_ca_certs: Optional[str] = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
root_path: Optional[str] = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
return_tokens_as_token_ids: bool = False
"""When `--max-logprobs` is specified, represents single tokens as
strings of the form 'token_id:{token_id}' so that tokens that are not
JSON-encodable can be identified."""
disable_frontend_multiprocessing: bool = False
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
enable_auto_tool_choice: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: Optional[str] = None
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
Required for `--enable-auto-tool-choice`. You can choose any option from
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
tool_parser_plugin: str = ""
"""Special the tool parser plugin write to parse the model-generated tool
into OpenAI API format, the name register in this plugin can be used in
`--tool-call-parser`."""
tool_server: Optional[str] = None
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
    purposes."""
log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH
"""Path to logging config JSON file for both vllm and uvicorn"""
max_log_len: Optional[int] = None
"""Max number of prompt characters or prompt ID numbers being printed in
log. The default of None means unlimited."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
enable_prompt_tokens_details: bool = False
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /get_tokenizer_info endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If True, log model outputs (generations).
Requires --enable-log-requests."""
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
from vllm.engine.arg_utils import get_kwargs
frontend_kwargs = get_kwargs(FrontendArgs)
# Special case: allowed_origins, allowed_methods, allowed_headers all
# need json.loads type
# Should also remove nargs
frontend_kwargs["allowed_origins"]["type"] = json.loads
frontend_kwargs["allowed_methods"]["type"] = json.loads
frontend_kwargs["allowed_headers"]["type"] = json.loads
del frontend_kwargs["allowed_origins"]["nargs"]
del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: LoRA modules need custom parser action and
# optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str
if "nargs" in frontend_kwargs["middleware"]:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.tool_parsers.keys())
parsers_str = ",".join(valid_tool_parsers)
frontend_kwargs["tool_call_parser"]["metavar"] = (
f"{{{parsers_str}}} or name registered in --tool-parser-plugin")
frontend_group = parser.add_argument_group(
title="Frontend",
description=FrontendArgs.__doc__,
)
for key, value in frontend_kwargs.items():
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
return parser
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Create the CLI argument parser used by the OpenAI API server.
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
register all arguments instead of manually enumerating them here. This
avoids code duplication and keeps the argument definitions in one place.
"""
parser.add_argument("model_tag",
type=str,
nargs="?",
help="The model tag to serve "
"(optional if specified in config)")
parser.add_argument(
"--headless",
action="store_true",
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
parser.add_argument("--api-server-count",
"-asc",
type=int,
default=1,
help="How many API server processes to run.")
parser.add_argument(
"--config",
help="Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html")
parser = FrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
if args.enable_log_outputs and not args.enable_log_requests:
raise TypeError("Error: --enable-log-outputs requires "
"--enable-log-requests")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)
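# A minimal `vllm serve` invocation sketch exercising these flags (model name
# and values are illustrative):
#
#     vllm serve my-org/my-model --port 8000 --api-key my-secret-key \
#         --enable-auto-tool-choice --tool-call-parser hermes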

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import lru_cache, partial
from typing import Optional, Union
import torch
from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor:
"""Logits processor for constraining generated tokens to a
specific set of token ids."""
def __init__(self, allowed_ids: Iterable[int]):
self.allowed_ids: Optional[list[int]] = list(allowed_ids)
self.mask: Optional[torch.Tensor] = None
def __call__(self, token_ids: list[int],
logits: torch.Tensor) -> torch.Tensor:
if self.mask is None:
self.mask = torch.ones((logits.shape[-1], ),
dtype=torch.bool,
device=logits.device)
self.mask[self.allowed_ids] = False
self.allowed_ids = None
logits.masked_fill_(self.mask, float("-inf"))
return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
allowed_token_ids: frozenset[int],
vocab_size: int,
) -> LogitsProcessor:
if not allowed_token_ids:
raise ValueError("Empty allowed_token_ids provided")
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
raise ValueError("allowed_token_ids contains "
"out-of-vocab token id")
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(
logit_bias: dict[int, float],
token_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
def get_logits_processors(
logit_bias: Optional[Union[dict[int, float], dict[str, float]]],
allowed_token_ids: Optional[list[int]],
tokenizer: AnyTokenizer,
) -> list[LogitsProcessor]:
logits_processors: list[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
clamped_logit_bias: dict[int, float] = {
int(token_id): min(100.0, max(-100.0, bias))
for token_id, bias in logit_bias.items()
}
except ValueError as exc:
raise ValueError(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc
# Check if token_id is within the vocab size
for token_id, bias in clamped_logit_bias.items():
if token_id < 0 or token_id >= len(tokenizer):
raise ValueError(f"token_id {token_id} in logit_bias contains "
"out-of-vocab token id")
logits_processors.append(
partial(logit_bias_logits_processor, clamped_logit_bias))
if allowed_token_ids is not None:
logits_processors.append(
_get_allowed_token_ids_logits_processor(
frozenset(allowed_token_ids), len(tokenizer)))
return logits_processors
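# A minimal usage sketch (token id and bias values are illustrative;
# `tokenizer` is any loaded tokenizer compatible with AnyTokenizer):
#
#     processors = get_logits_processors(logit_bias={"50256": -100.0},
#                                        allowed_token_ids=None,
#                                        tokenizer=tokenizer)
#     params = SamplingParams(logits_processors=processors)  # V0-style usage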

File diff suppressed because it is too large.

View File

@@ -0,0 +1,491 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import tempfile
from argparse import Namespace
from collections.abc import Awaitable
from http import HTTPStatus
from io import StringIO
from typing import Callable, Optional
import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse, ErrorResponse,
RerankResponse, ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.")
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it "
"to the output URL.",
)
parser.add_argument("--response-role",
type=optional_type(str),
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument("--enable-metrics",
action="store_true",
help="Enable Prometheus metrics")
parser.add_argument(
"--url",
type=str,
default="0.0.0.0",
help="URL to the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
return parser
def parse_args():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
return make_arg_parser(parser).parse_args()
# Explicitly use a pure-text format, with a newline at the end.
# This makes it impossible to see the animation in the progress bar,
# but avoids interfering with ray or multiprocessing, which wrap
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: Optional[tqdm] = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
self._pbar = tqdm(total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(output_path: str,
batch_outputs: list[BatchRequestOutput]) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str,
from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
total=1000)) as session:
if from_file:
with open(data_or_file, "rb") as file:
async with session.put(output_url,
data=file) as response:
if response.status != 200:
raise Exception(f"Failed to upload file.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}")
else:
async with session.put(output_url,
data=data_or_file) as response:
if response.status != 200:
raise Exception(f"Failed to upload data.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}")
except Exception as e:
if attempt < max_retries:
logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
output_tmp_dir: str) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
logger.info("Writing outputs to temporary local file %s",
f.name)
await write_local_file(f.name, batch_outputs)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
def make_error_request_output(request: BatchRequestInput,
error_msg: str) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"vllm-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
RerankResponse),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=response, request_id=f"vllm-batch-{random_uuid()}"),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=response.error.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=response,
)
else:
batch_output = make_error_request_output(
request, error_msg="Request must not be sent in stream mode")
tracker.completed()
return batch_output
async def run_batch(
engine_client: EngineClient,
vllm_config: VllmConfig,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
model_config = vllm_config.model_config
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported_tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
)
openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if "generate" in supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if "embed" in supported_tasks else None
enable_serving_reranking = ("classify" in supported_tasks and getattr(
model_config.hf_config, "num_labels", 0) == 1)
openai_serving_scores = ServingScores(
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
) if ("embed" in supported_tasks or enable_serving_reranking) else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = openai_serving_chat.create_chat_completion if \
openai_serving_chat is not None else None
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg=
"The model does not support Chat Completions API",
))
continue
response_futures.append(
run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
embed_handler_fn = openai_serving_embedding.create_embedding if \
openai_serving_embedding is not None else None
if embed_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
))
continue
response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/score"):
score_handler_fn = openai_serving_scores.create_score if \
openai_serving_scores is not None else None
if score_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Scores API",
))
continue
response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = openai_serving_scores.do_rerank if \
openai_serving_scores is not None else None
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
))
continue
response_futures.append(
run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
))
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.usage.usage_lib import UsageContext
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
vllm_config = await engine_client.get_vllm_config()
await run_batch(engine_client, vllm_config, args)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))
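# Each input line is a JSON object in the OpenAI batch format; an illustrative
# /v1/chat/completions line looks like (all values are placeholders):
#
#     {"custom_id": "request-1", "method": "POST",
#      "url": "/v1/chat/completions",
#      "body": {"model": "my-model",
#               "messages": [{"role": "user", "content": "Hello!"}]}}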

File diff suppressed because it is too large.

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ClassificationData,
ClassificationRequest,
ClassificationResponse,
ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
OpenAIServing,
ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
@override
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
if isinstance(ctx.request.input, str) and not ctx.request.input:
return self.create_error_response(
"Input cannot be empty for classification",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
return None
try:
ctx.tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request))
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@override
def _build_response(
self,
ctx: ServeContext,
) -> Union[ClassificationResponse, ErrorResponse]:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput],
ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label",
{}).get(predicted_index)
item = ClassificationData(
index=idx,
label=label,
probs=probs,
num_classes=len(probs),
)
items.append(item)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ClassificationResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
def _build_render_config(self,
request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> Union[ClassificationResponse, ErrorResponse]:
model_name = self.models.model_name()
request_id = (f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request)}")
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
) -> Union[PoolingParams, ErrorResponse]:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("classify", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
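# A minimal request sketch for this handler (route and payload values are
# illustrative; the actual route is registered by the API server):
#
#     POST /classify
#     {"model": "my-classifier", "input": "vLLM is great"}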

View File

@@ -0,0 +1,692 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast
import jinja2
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import as_list, merge_async_iterators
logger = init_logger(__name__)
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)
async def create_completion(
self,
request: CompletionRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response(
"Echo is unsupported with prompt embeds.")
if (request.prompt_logprobs is not None
and request.prompt_embeds is not None):
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds.")
request_id = (
f"cmpl-"
f"{self._base_request_id(raw_request, request.request_id)}")
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
else:
assert_never(engine_prompt)
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Note that best_of is only supported in V0. In addition,
        # we do not stream the results when beam search is used.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
)
# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
                # The output should contain the input text.
                # We did not pass it into the vLLM engine to avoid redundancy
                # with the input token IDs.
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
        # When the user requests streaming but we don't stream, we still
        # need to return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
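    # A minimal request sketch for this handler (field values are
    # illustrative):
    #
    #     POST /v1/completions
    #     {"model": "my-model", "prompt": "Hello", "max_tokens": 16,
    #      "stream": true, "logprobs": 1}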
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
if stream_options:
include_usage = (stream_options.include_usage
or enable_force_include_usage)
include_continuous_usage = (include_usage and
stream_options.continuous_usage_stats)
else:
include_usage, include_continuous_usage = False, False
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[dict[
int, Logprob]]]]
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: Optional[list[int]] = None
assert request.max_tokens is not None
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i]:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (not delta_text and not delta_token_ids
and not previous_num_tokens[i]):
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
return_as_token_id=request.
return_tokens_as_token_ids,
)
else:
logprobs = None
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(as_list(output.token_ids) if
request.return_token_ids else None),
)
],
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage_info,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True)
yield f"data: {final_usage_data}\n\n"
            # Report the aggregate usage across all choices to the
            # FastAPI middleware.
request_metadata.final_usage_info = final_usage_info
except Exception as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: list[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[dict[int,
Logprob]]]]
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(prompt_token_ids
if request.return_token_ids else None),
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None),
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if (self.enable_prompt_tokens_details and last_final_res
and last_final_res.num_cached_tokens):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=last_final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
return_as_token_id: Optional[bool] = None,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: list[int] = []
out_token_logprobs: list[Optional[float]] = []
out_tokens: list[str] = []
out_top_logprobs: list[Optional[dict[str, float]]] = []
last_token_len = 0
should_return_as_token_id = (return_as_token_id
if return_as_token_id is not None else
self.return_tokens_as_token_ids)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_token,
token_id,
tokenizer,
return_as_token_id=should_return_as_token_id,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
                # Make sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the OpenAI API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
})
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: Optional[int] = None,
) -> RenderConfig:
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
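        # e.g., with max_model_len=8192 and request.max_tokens=1024 (illustrative
        # values), the prompt may use at most 7168 tokens.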
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo
and not request.return_token_ids),
)

View File

@@ -0,0 +1,631 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, Literal, Optional, Union, cast
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never, override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
OpenAIServing,
ServeContext,
TextTokensPrompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingOutput, PoolingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
logger = init_logger(__name__)
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
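        # Illustrative client-side decoding of this base64 payload (variable
        # names are assumed):
        #   import base64; import numpy as np
        #   vec = np.frombuffer(base64.b64decode(b64_str), dtype=np.float32)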
assert_never(encoding_format)
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pooler_config = self.model_config.pooler_config
# Avoid repeated attribute lookups
self.supports_chunked_processing = bool(
pooler_config and pooler_config.enable_chunked_processing)
self.max_embed_len = (pooler_config.max_embed_len if pooler_config
and pooler_config.max_embed_len else None)
@override
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
(
_,
_,
ctx.engine_prompts,
) = await self._preprocess_chat(
ctx.request,
tokenizer,
ctx.request.messages,
chat_template=ctx.request.chat_template
or ctx.chat_template,
chat_template_content_format=ctx.
chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=False,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(
self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)
@override
def _build_response(
self,
ctx: ServeContext,
) -> Union[EmbeddingResponse, ErrorResponse]:
items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput],
ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(embedding_res.outputs,
ctx.request.encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
def _get_max_position_embeddings(self) -> int:
"""Get the model's effective maximum sequence length for chunking."""
return self.model_config.max_model_len
def _should_use_chunked_processing(self, request) -> bool:
"""Check if chunked processing should be used for this request."""
return isinstance(
request,
(EmbeddingCompletionRequest,
EmbeddingChatRequest)) and self.supports_chunked_processing
async def _process_chunked_request(
self,
ctx: EmbeddingServeContext,
original_prompt: TextTokensPrompt,
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
token_ids = original_prompt["prompt_token_ids"]
# Split into chunks using max_position_embeddings
max_pos_embeddings = self._get_max_position_embeddings()
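        # e.g., with max_pos_embeddings=4096 (illustrative), a 10_000-token prompt
        # is split by chunk_list into chunks of 4096, 4096, and 1808 tokens.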
# Process all chunks for MEAN aggregation
for chunk_idx, chunk_tokens in enumerate(
chunk_list(token_ids, max_pos_embeddings)):
# Create a request ID for this chunk
chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
f"chunk-{chunk_idx}")
# Create engine prompt for this chunk
chunk_engine_prompt = EngineTokensPrompt(
prompt_token_ids=chunk_tokens)
# Create chunk request prompt for logging
chunk_text = ""
chunk_request_prompt = TextTokensPrompt(
prompt=chunk_text, prompt_token_ids=chunk_tokens)
# Log the chunk
self._log_inputs(chunk_request_id,
chunk_request_prompt,
params=pooling_params,
lora_request=ctx.lora_request)
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(original_generator)
return generators
def _validate_input(
self,
request,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
"""Override to support chunked processing for embedding requests."""
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request,
(EmbeddingCompletionRequest, EmbeddingChatRequest)):
# Check if chunked processing is enabled for pooling models
enable_chunked = self._should_use_chunked_processing(request)
# Use max_position_embeddings for chunked processing decisions
max_pos_embeddings = self._get_max_position_embeddings()
# Determine the effective max length for validation
if self.max_embed_len is not None:
# Use max_embed_len for validation instead of max_model_len
length_type = "maximum embedding input length"
max_length_value = self.max_embed_len
else:
# Fall back to max_model_len validation (original behavior)
length_type = "maximum context length"
max_length_value = self.max_model_len
validation_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input.")
chunked_processing_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input "
"or enable chunked processing.")
# Check if input exceeds max length
if token_num > max_length_value:
raise ValueError(
validation_error_msg.format(
length_type=length_type,
max_length_value=max_length_value,
token_num=token_num))
# Check for chunked processing
# when exceeding max_position_embeddings
if token_num > max_pos_embeddings:
if enable_chunked:
# Allow long inputs when chunked processing is enabled
logger.info(
"Input length %s exceeds max_position_embeddings "
"%s, will use chunked processing", token_num,
max_pos_embeddings)
else:
raise ValueError(
chunked_processing_error_msg.format(
length_type="maximum position embeddings length",
max_length_value=max_pos_embeddings,
token_num=token_num))
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# For other request types, use the parent's implementation
return super()._validate_input(request, input_ids, input_text)
def _is_text_tokens_prompt(self, prompt) -> bool:
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt)
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: EngineTokensPrompt,
pooling_params: PoolingParams,
trace_headers: Optional[Mapping[str, str]],
prompt_index: int,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
self._log_inputs(request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request)
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
# If no chunked processing needed, delegate to parent class
if not use_chunked:
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[AsyncGenerator[Union[RequestOutput,
PoolingRequestOutput],
None]] = []
try:
trace_headers = (None if ctx.raw_request is None else await
self._get_trace_headers(ctx.raw_request.headers))
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
# Verify and set the task for pooling params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
max_pos_embeddings = self._get_max_position_embeddings()
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(engine_prompt):
# Cast to TextTokensPrompt since we've verified
# prompt_token_ids
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
if (len(text_tokens_prompt["prompt_token_ids"])
> max_pos_embeddings):
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
ctx, text_tokens_prompt, pooling_params,
trace_headers, i)
generators.extend(chunk_generators)
continue
# Normal processing for short prompts or non-token prompts
generator = await self._create_single_prompt_generator(
ctx, engine_prompt, pooling_params, trace_headers, i)
generators.append(generator)
from vllm.utils import merge_async_iterators
ctx.result_generator = merge_async_iterators(*generators)
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results
with support for chunked processing.
For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
# Check if we used chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
if not use_chunked:
return await super()._collect_batch(ctx=ctx)
if ctx.result_generator is None:
return self.create_error_response(
"Result generator not available")
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
prompt_aggregators: dict[int, dict[str, Any]] = {}
short_prompts_results: dict[int, PoolingRequestOutput] = {}
async for result_idx, result in ctx.result_generator:
if "-chunk-" in result.request_id:
# Extract prompt_idx from chunked request_id
parts = result.request_id.split("-")
try:
prompt_idx = int(parts[parts.index("prompt") + 1])
except (ValueError, IndexError):
                    # Fallback: use result_idx if parsing fails
prompt_idx = result_idx
# Initialize aggregator for this prompt if needed
if prompt_idx not in prompt_aggregators:
prompt_aggregators[prompt_idx] = {
'weighted_sum': None,
'total_weight': 0,
'chunk_count': 0,
'request_id': result.request_id.split("-chunk-")[0]
}
aggregator = prompt_aggregators[prompt_idx]
# MEAN pooling with online weighted averaging
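                # The final embedding is sum_i(weight_i * emb_i) / sum_i(weight_i),
                # where weight_i is the chunk's token count; assuming each chunk
                # embedding is a mean over its tokens, this equals mean pooling
                # over the full sequence.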
# Ensure result is PoolingRequestOutput
# for embedding processing
if not isinstance(result, PoolingRequestOutput):
return self.create_error_response(
f"Expected PoolingRequestOutput for "
f"chunked embedding, got "
f"{type(result).__name__}")
# Handle both PoolingOutput and
# EmbeddingOutput types
if hasattr(result.outputs, 'data'):
# PoolingOutput case
embedding_data = result.outputs.data
elif hasattr(result.outputs, 'embedding'):
# EmbeddingOutput case -
# convert embedding list to tensor
embedding_data = result.outputs.embedding
else:
return self.create_error_response(
f"Unsupported output type: "
f"{type(result.outputs).__name__}")
if not isinstance(embedding_data, torch.Tensor):
embedding_data = torch.tensor(embedding_data,
dtype=torch.float32)
if result.prompt_token_ids is None:
return self.create_error_response(
"prompt_token_ids cannot be None for "
"chunked processing")
weight = len(result.prompt_token_ids)
weighted_embedding = embedding_data.to(
dtype=torch.float32) * weight
if aggregator['weighted_sum'] is None:
# First chunk
aggregator['weighted_sum'] = weighted_embedding
else:
# Accumulate
aggregator['weighted_sum'] += weighted_embedding
aggregator['total_weight'] += weight
aggregator['chunk_count'] += 1
else:
# Non-chunked result - extract prompt_idx from request_id
parts = result.request_id.split("-")
try:
# Last part should be prompt index
prompt_idx = int(parts[-1])
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result)
# Finalize aggregated results
final_res_batch: list[Union[PoolingRequestOutput,
EmbeddingRequestOutput]] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
if prompt_idx in prompt_aggregators:
# Finalize MEAN aggregation for this chunked prompt
aggregator = prompt_aggregators[prompt_idx]
weighted_sum = aggregator['weighted_sum']
total_weight = aggregator['total_weight']
if (weighted_sum is not None
and isinstance(weighted_sum, torch.Tensor)
and isinstance(total_weight,
(int, float)) and total_weight > 0):
# Compute final mean embedding
final_embedding = weighted_sum / total_weight
# Create a PoolingRequestOutput
# for the aggregated result
pooling_output_data = PoolingOutput(
data=final_embedding)
# Get original prompt token IDs for this prompt
original_prompt = ctx.engine_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt):
return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a "
f"TextTokensPrompt")
original_token_ids = cast(
TextTokensPrompt,
original_prompt)["prompt_token_ids"]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator['request_id'],
prompt_token_ids=original_token_ids,
outputs=pooling_output_data,
finished=True)
final_res_batch.append(pooling_request_output)
else:
return self.create_error_response(
f"Failed to aggregate chunks "
f"for prompt {prompt_idx}")
elif prompt_idx in short_prompts_results:
final_res_batch.append(
cast(PoolingRequestOutput,
short_prompts_results[prompt_idx]))
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}")
ctx.final_res_batch = cast(
list[Union[RequestOutput, PoolingRequestOutput]],
final_res_batch)
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
model_name = self.models.model_name()
request_id = (
f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request, request.request_id)}")
ctx = EmbeddingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: ServeContext[EmbeddingRequest],
) -> Union[PoolingParams, ErrorResponse]:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params

View File

@@ -0,0 +1,992 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse,
IOProcessorRequest,
PoolingResponse, RerankRequest,
ResponsesRequest, ScoreRequest,
ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
RenderConfig)
# yapf: enable
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict, MultiModalUUIDDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
merge_async_iterators, random_uuid)
logger = init_logger(__name__)
CompletionLikeRequest = Union[
CompletionRequest,
DetokenizeRequest,
EmbeddingCompletionRequest,
RerankRequest,
ClassificationRequest,
ScoreRequest,
TokenizeCompletionRequest,
]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[
CompletionLikeRequest,
ChatLikeRequest,
SpeechToTextRequest,
ResponsesRequest,
IOProcessorRequest,
]
AnyResponse = Union[
CompletionResponse,
ChatCompletionResponse,
EmbeddingResponse,
TranscriptionResponse,
TokenizeResponse,
PoolingResponse,
ClassificationResponse,
ScoreResponse,
]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: list[int]
class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt)
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt)
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Optional[Sequence[RequestPrompt]] = []
engine_prompts: Optional[list[EngineTokensPrompt]] = []
model_config = ConfigDict(arbitrary_types_allowed=True)
class ResponseGenerationMixin(BaseModel):
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: Optional[AsyncGenerator[tuple[int, Union[
RequestOutput, PoolingRequestOutput]], None]] = None
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(
RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
# Shared across all requests
request: RequestT
raw_request: Optional[Request] = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
lora_request: Optional[LoRARequest] = None
# Shared across most requests
tokenizer: Optional[AnyTokenizer] = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
ClassificationServeContext = ServeContext[ClassificationRequest]
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: Optional[str] = None
chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request's ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._async_tokenizer_pool: dict[AnyTokenizer,
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
"""
return None
def _build_response(
self,
ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return self.create_error_response("unimplemented endpoint")
async def handle(
self,
ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
generation = self._pipeline(ctx)
async for response in generation:
return response
return self.create_error_response("No response yielded from pipeline")
async def _pipeline(
self,
ctx: ServeContext,
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
"""Execute the request processing pipeline yielding responses."""
if error := await self._check_model(ctx.request):
yield error
if error := self._validate_request(ctx):
yield error
preprocess_ret = await self._preprocess(ctx)
if isinstance(preprocess_ret, ErrorResponse):
yield preprocess_ret
generators_ret = await self._prepare_generators(ctx)
if isinstance(generators_ret, ErrorResponse):
yield generators_ret
collect_ret = await self._collect_batch(ctx)
if isinstance(collect_ret, ErrorResponse):
yield collect_ret
yield self._build_response(ctx)
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None)
if (truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.max_model_len):
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
return None
def _create_pooling_params(
self,
ctx: ServeContext,
) -> Union[PoolingParams, ErrorResponse]:
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters")
return ctx.request.to_pooling_params()
async def _prepare_generators(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[Union[RequestOutput,
PoolingRequestOutput],
None]] = []
try:
trace_headers = (None if ctx.raw_request is None else await
self._get_trace_headers(ctx.raw_request.headers))
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def _collect_batch(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Collect batch results from the result generator."""
try:
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[Optional[Union[RequestOutput,
PoolingRequestOutput]]]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
return self.create_error_response(
"Result generator not available")
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts")
ctx.final_res_batch = [
res for res in final_res_batch if res is not None
]
return None
except Exception as e:
return self.create_error_response(str(e))
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
if self.log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(error=ErrorInfo(
message=message, type=err_type, code=status_code.value))
def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
json_str = json.dumps(
self.create_error_response(message=message,
err_type=err_type,
status_code=status_code).model_dump())
return json_str
async def _check_model(
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
error_response = None
if self._is_model_supported(request.model):
return None
if request.model in self.models.lora_requests:
return None
if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
(load_result := await self.models.resolve_lora(request.model))):
if isinstance(load_result, LoRARequest):
return None
if (isinstance(load_result, ErrorResponse) and
load_result.error.code == HTTPStatus.BAD_REQUEST.value):
error_response = load_result
return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
def _get_active_default_mm_loras(
self, request: AnyRequest) -> Optional[LoRARequest]:
"""Determine if there are any active default multimodal loras."""
# TODO: Currently this is only enabled for chat completions
# to be better aligned with only being enabled for .generate
# when run offline. It would be nice to support additional
# tasks types in the future.
message_types = self._get_message_types(request)
default_mm_loras = set()
for lora in self.models.lora_requests.values():
# Best effort match for default multimodal lora adapters;
# There is probably a better way to do this, but currently
# this matches against the set of 'types' in any content lists
# up until '_', e.g., to match audio_url -> audio
if lora.lora_name in message_types:
default_mm_loras.add(lora)
# Currently only support default modality specific loras if
# we have exactly one lora matched on the request.
if len(default_mm_loras) == 1:
return default_mm_loras.pop()
return None
def _maybe_get_adapters(
self,
request: AnyRequest,
supports_default_mm_loras: bool = False,
) -> Optional[LoRARequest]:
if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
return default_mm_lora
if self._is_model_supported(request.model):
return None
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _get_message_types(self, request: AnyRequest) -> set[str]:
"""Retrieve the set of types from message content dicts up
until `_`; we use this to match potential multimodal data
with default per modality loras.
"""
message_types: set[str] = set()
if not hasattr(request, "messages"):
return message_types
for message in request.messages:
if (isinstance(message, dict) and "content" in message
and isinstance(message["content"], list)):
for content_dict in message["content"]:
if "type" in content_dict:
message_types.add(content_dict["type"].split("_")[0])
return message_types
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
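            # e.g., truncate_prompt_tokens=-1 keeps at most max_model_len tokens.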
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len,
)
else:
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: Optional[AnyTokenizer],
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
input_ids = prompt_ids[-self.max_model_len:]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
if tokenizer is None:
input_text = ""
else:
async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: AnyRequest,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids)
        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
if isinstance(
request,
(
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ScoreRequest,
RerankRequest,
ClassificationRequest,
),
):
# Note: input length can be up to the entire model context length
# since these requests don't generate tokens.
if token_num > self.max_model_len:
operations: dict[type[AnyRequest], str] = {
ScoreRequest: "score",
ClassificationRequest: "classification",
}
operation = operations.get(type(request),
"embedding generation")
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
if isinstance(
request,
(TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest),
):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# chat completion endpoint supports max_completion_tokens
if isinstance(request, ChatCompletionRequest):
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
max_tokens = request.max_completion_tokens or request.max_tokens
else:
max_tokens = getattr(request, "max_tokens", None)
# Note: input length can be up to model context length - 1 for
# completion-like requests.
if token_num >= self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, your request has "
f"{token_num} input tokens. Please reduce the length of "
"the input messages.")
if (max_tokens is not None
and token_num + max_tokens > self.max_model_len):
raise ValueError(
"'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is "
f"{self.max_model_len} tokens and your request has "
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
f" - {token_num}).")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, list[int]],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation that tokenizes a single prompt input.
"""
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
add_special_tokens=add_special_tokens,
):
return result
raise ValueError("No results yielded from tokenization")
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, list[int]]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
"""
A simpler implementation that tokenizes multiple prompt inputs.
"""
for prompt in prompt_inputs:
if isinstance(prompt, str):
yield await self._normalize_prompt_text_to_input(
request,
prompt=prompt,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
yield await self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt,
tokenizer=tokenizer,
)
async def _preprocess_chat(
self,
request: Union[ChatLikeRequest, ResponsesRequest],
tokenizer: AnyTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tool_dicts: Optional[list[dict[str, Any]]] = None,
documents: Optional[list[dict[str, str]]] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages,
model_config,
tokenizer,
content_format=resolved_content_format,
)
_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
)
_chat_template_kwargs.update(chat_template_kwargs or {})
request_prompt: Union[str, list[int]]
if tokenizer is None:
request_prompt = "placeholder"
elif isinstance(tokenizer, MistralTokenizer):
request_prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
**_chat_template_kwargs,
)
else:
request_prompt = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
mm_data = await mm_data_future
        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM).
should_parse_tools = tool_parser is not None and (hasattr(
request, "tool_choice") and request.tool_choice != "none")
if should_parse_tools:
if not isinstance(request, ChatCompletionRequest):
msg = "Tool usage is only supported for Chat Completions API"
raise NotImplementedError(msg)
request = tool_parser(tokenizer).adjust_request( # type: ignore
request=request)
if tokenizer is None:
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised")
prompt_inputs = TextTokensPrompt(prompt=request_prompt,
prompt_token_ids=[1])
elif isinstance(request_prompt, str):
prompt_inputs = await self._tokenize_prompt_input_async(
request,
tokenizer,
request_prompt,
add_special_tokens=add_special_tokens,
)
else:
# For MistralTokenizer
assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt,
)
engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return conversation, [request_prompt], [engine_prompt]
async def _generate_with_builtin_tools(
self,
request_id: str,
request_prompt: RequestPrompt,
engine_prompt: EngineTokensPrompt,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
**kwargs,
):
orig_priority = priority
while True:
self._log_inputs(
request_id,
request_prompt,
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
priority=priority,
**kwargs,
)
async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
yield context
if not context.need_builtin_tool_call():
# The model did not ask for a tool call, so we're done.
break
# Call the tool and update the context with the result.
tool_output = await context.call_tool()
context.append_output(tool_output)
# TODO: uncomment this and enable tool output streaming
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids.
prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
prompt_token_ids)
            # OPTIMIZATION: use a lower priority value (i.e., earlier handling
            # under priority scheduling) for follow-up turns so the tool-call
            # loop of an in-flight request completes sooner.
priority = orig_priority - 1
def _log_inputs(
self,
request_id: str,
inputs: Union[RequestPrompt, PromptType],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
else:
prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
params=params,
lora_request=lora_request,
)
async def _get_trace_headers(
self,
headers: Headers,
) -> Optional[Mapping[str, str]]:
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
if is_tracing_enabled:
return extract_trace_headers(headers)
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
@staticmethod
def _base_request_id(raw_request: Optional[Request],
default: Optional[str] = None) -> Optional[str]:
"""Pulls the request id to use from a header, if provided"""
default = default or random_uuid()
if raw_request is None:
return default
return raw_request.headers.get("X-Request-Id", default)
@staticmethod
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False,
) -> str:
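        # e.g., with return_as_token_id=True, token id 1234 is rendered as
        # "token_id:1234" instead of its decoded text.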
if return_as_token_id:
return f"token_id:{token_id}"
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: Optional[str]) -> bool:
if not model_name:
return True
return self.models.is_base_model(model_name)
def clamp_prompt_logprobs(
prompt_logprobs: Union[PromptLogprobs,
None], ) -> Union[PromptLogprobs, None]:
if prompt_logprobs is None:
return prompt_logprobs
for logprob_dict in prompt_logprobs:
if logprob_dict is None:
continue
for logprob_values in logprob_dict.values():
if logprob_values.logprob == float("-inf"):
logprob_values.logprob = -9999.0
return prompt_logprobs

View File

@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
LoadLoRAAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoRAAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: list[BaseModelPath],
*,
lora_modules: Optional[list[LoRAModulePath]] = None,
):
super().__init__()
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client
self.model_config = model_config
self.static_lora_modules = lora_modules
self.lora_requests: dict[str, LoRARequest] = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
):
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
lora_name=lora.name)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
        Parameters:
        - lora_request: Optional LoRARequest that may contain a base_model_name.
        Returns:
        - str: The LoRA adapter name if a LoRA request is given, otherwise the
          name of the first base model.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all
adapters"""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self,
request: LoadLoRAAdapterRequest,
base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]:
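        # Illustrative payload for POST /v1/load_lora_adapter (values assumed):
        #   {"lora_name": "my-adapter", "lora_path": "/path/to/adapter"}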
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_load_lora_adapter_request(
request)
if error_check_ret is not None:
return error_check_ret
lora_path = request.lora_path
unique_id = self.lora_id_counter.inc(1)
lora_request = LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path)
if base_model_name is not None and self.is_base_model(
base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except Exception as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
return create_error_response(message=str(e),
err_type=error_type,
status_code=status_code)
self.lora_requests[lora_name] = lora_request
logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
lora_name, lora_path)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_unload_lora_adapter_request(
request)
if error_check_ret is not None:
return error_check_ret
# Safe to delete now since we hold the lock
del self.lora_requests[lora_name]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if request.lora_name in self.lora_requests:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been "
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
        # If 'lora_name' is not provided, return an error
if not request.lora_name:
return create_error_response(
message=
"'lora_name' needs to be provided to unload a LoRA adapter.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if request.lora_name not in self.lora_requests:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
return None
async def resolve_lora(
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
"""Attempt to resolve a LoRA adapter using available resolvers.
Args:
lora_name: Name/identifier of the LoRA adapter
Returns:
LoRARequest if found and loaded successfully.
ErrorResponse (404) if no resolver finds the adapter.
ErrorResponse (400) if adapter(s) are found but none load.
"""
async with self.lora_resolver_lock[lora_name]:
# First check if this LoRA is already loaded
if lora_name in self.lora_requests:
return self.lora_requests[lora_name]
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
found_adapter = False
# Try to resolve using available resolvers
for resolver in self.lora_resolvers:
lora_request = await resolver.resolve_lora(
base_model_name, lora_name)
if lora_request is not None:
found_adapter = True
lora_request.lora_int_id = unique_id
try:
await self.engine_client.add_lora(lora_request)
self.lora_requests[lora_name] = lora_request
logger.info(
"Resolved and loaded LoRA adapter '%s' using %s",
lora_name, resolver.__class__.__name__)
return lora_request
except BaseException as e:
logger.warning(
"Failed to load LoRA '%s' resolved by %s: %s. "
"Trying next resolver.", lora_name,
resolver.__class__.__name__, e)
continue
if found_adapter:
# An adapter was found, but all attempts to load it failed.
return create_error_response(
message=(f"LoRA adapter '{lora_name}' was found "
"but could not be loaded."),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
else:
# No adapter was found
return create_error_response(
message=f"LoRA adapter {lora_name} does not exist",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(error=ErrorInfo(
message=message, type=err_type, code=status_code.value))
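The load and unload handlers above back vLLM's dynamic LoRA adapter endpoints. Below is a minimal client-side sketch of exercising them; the host and the /v1/load_lora_adapter and /v1/unload_lora_adapter routes are assumptions inferred from the request models handled above, and the adapter name and path are made up.

import requests

BASE_URL = "http://localhost:8000"  # assumed server address


def load_adapter(name: str, path: str) -> None:
    # Mirrors LoadLoRAAdapterRequest: both fields are required, otherwise
    # the handler answers with InvalidUserInput (HTTP 400).
    resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                         json={"lora_name": name, "lora_path": path})
    print(resp.status_code, resp.text)


def unload_adapter(name: str) -> None:
    # Mirrors UnloadLoRAAdapterRequest: unknown names yield NotFoundError (404).
    resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                         json={"lora_name": name})
    print(resp.status_code, resp.text)


if __name__ == "__main__":
    load_adapter("sql-lora", "/models/adapters/sql-lora")  # hypothetical adapter
    unload_adapter("sql-lora")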

View File

@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast
import jinja2
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
def _get_data(
output: PoolingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
if encoding_format == "float":
return output.data.tolist()
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pt_float32 = output.data.to(dtype=torch.float32)
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format)
class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
vllm_config: VllmConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=vllm_config.model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = self.models.model_name()
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported")
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
elif isinstance(request, PoolingChatRequest):
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("encode", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
engine_prompt,
params=pooling_params,
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request,
(PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: list[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(list[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_pooling_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
request.encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> PoolingResponse:
items: list[PoolingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=_get_data(final_res.outputs, encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
def _build_render_config(
self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)
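_get_data above returns pooling vectors either as a plain float list or as base64-encoded float32 bytes. A minimal sketch of reversing the base64 path on the client side; the dtype follows the explicit float32 cast in _get_data, and the example vector is made up.

import base64

import numpy as np


def decode_pooling_data(data: str) -> np.ndarray:
    # Inverse of the base64 branch in _get_data: raw bytes -> float32 vector.
    return np.frombuffer(base64.b64decode(data), dtype=np.float32)


# Round trip with a made-up 3-dimensional vector.
vec = np.array([0.1, -0.2, 0.3], dtype=np.float32)
encoded = base64.b64encode(vec.tobytes()).decode("utf-8")
print(decode_pooling_data(encoded))  # [ 0.1 -0.2  0.3]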

File diff suppressed because it is too large

View File

@@ -0,0 +1,479 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Optional, Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
RerankRequest, RerankResponse,
RerankResult, RerankUsage,
ScoreRequest, ScoreResponse,
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt)
# yapf: enable
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
class ServingScores(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
async def _embedding_score(
self,
tokenizer: AnyTokenizer,
texts_1: list[str],
texts_2: list[str],
request: Union[RerankRequest, ScoreRequest],
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
input_texts = texts_1 + texts_2
engine_prompts: list[TokensPrompt] = []
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
tokenization_kwargs = tokenization_kwargs or {}
tokenized_prompts = await asyncio.gather(
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = \
self._validate_input(
request,
tok_result["input_ids"],
input_text)
engine_prompts.append(
TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request)
generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: list[PoolingRequestOutput] = []
embeddings: list[Optional[PoolingRequestOutput]] =\
[None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
emb_texts_1: list[PoolingRequestOutput] = []
emb_texts_2: list[PoolingRequestOutput] = []
for i in range(0, len(texts_1)):
assert (emb := embeddings[i]) is not None
emb_texts_1.append(emb)
for i in range(len(texts_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_texts_2.append(emb)
if len(emb_texts_1) == 1:
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
embed_1=emb_texts_1,
embed_2=emb_texts_2)
return final_res_batch
def _preprocess_score(
self,
request: Union[RerankRequest, ScoreRequest],
tokenizer: AnyTokenizer,
tokenization_kwargs: dict[str, Any],
data_1: Union[str, ScoreContentPartParam],
data_2: Union[str, ScoreContentPartParam],
) -> tuple[str, TokensPrompt]:
model_config = self.model_config
full_prompt, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=data_1,
data_2=data_2,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
self._validate_input(request, engine_prompt["prompt_token_ids"],
full_prompt)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
return full_prompt, engine_prompt
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
data_1: Union[list[str], list[ScoreContentPartParam]],
data_2: Union[list[str], list[ScoreContentPartParam]],
request: Union[RerankRequest, ScoreRequest],
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
if len(data_1) == 1:
data_1 = data_1 * len(data_2)
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
tokenization_kwargs = tokenization_kwargs or {}
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
preprocess_async = make_async(self._preprocess_score,
executor=self._tokenizer_executor)
preprocessed_prompts = await asyncio.gather(
*(preprocess_async(request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2) for t1, t2 in input_pairs))
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params()
try:
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=default_pooling_params,
lora_request=lora_request)
if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = {
"compressed_token_type_ids": compressed
}
else:
pooling_params = (default_pooling_params)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: list[
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
async for i, res in result_generator:
final_res_batch[i] = res
return [out for out in final_res_batch if out is not None]
async def _run_scoring(
self,
data_1: Union[list[str], str, ScoreMultiModalParam],
data_2: Union[list[str], str, ScoreMultiModalParam],
request: Union[ScoreRequest, RerankRequest],
request_id: str,
raw_request: Optional[Request] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
tokenization_kwargs)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
if not self.model_config.is_multimodal_model and (isinstance(
data_1, dict) or isinstance(data_2, dict)):
raise ValueError(
f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501
)
if isinstance(data_1, str):
data_1 = [data_1]
elif isinstance(data_1, dict):
data_1 = data_1.get("content") # type: ignore[assignment]
if isinstance(data_2, str):
data_2 = [data_2]
elif isinstance(data_2, dict):
data_2 = data_2.get("content") # type: ignore[assignment]
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers)
else:
return await self._embedding_score(
tokenizer=tokenizer,
texts_1=data_1, # type: ignore[arg-type]
texts_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers)
async def create_score(
self,
request: ScoreRequest,
raw_request: Optional[Request] = None,
) -> Union[ScoreResponse, ErrorResponse]:
"""
Score API similar to Sentence Transformers cross encoder.
See https://sbert.net/docs/package_reference/cross_encoder
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.time())
try:
final_res_batch = await self._run_scoring(
request.text_1,
request.text_2,
request,
request_id,
raw_request,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_score_response(
final_res_batch,
request_id,
created_time,
self.models.model_name(),
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs.
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py;
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"rerank-{self._base_request_id(raw_request)}"
documents = request.documents
top_n = request.top_n if request.top_n > 0 else (
len(documents)
if isinstance(documents, list) else len(documents["content"]))
try:
final_res_batch = await self._run_scoring(
request.query,
documents,
request,
request_id,
raw_request,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_rerank_response(
final_res_batch,
request_id,
self.models.model_name(),
documents,
top_n,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def request_output_to_score_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> ScoreResponse:
items: list[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
item = ScoreResponseData(
index=idx,
score=classify_res.outputs.score,
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
def request_output_to_rerank_response(
self, final_res_batch: list[PoolingRequestOutput], request_id: str,
model_name: str, documents: Union[list[str], ScoreMultiModalParam],
top_n: int) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: list[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx]) if isinstance(
documents, list) else RerankDocument(
multi_modal=documents["content"][idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens))
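do_rerank above follows the JinaAI-style rerank interface: a query, a list of documents, and an optional top_n (non-positive values fall back to the full document count, and results are sorted by relevance_score before truncation). A minimal request sketch follows, assuming the handler is mounted at /v1/rerank on a local server; the route, host, and model name are assumptions, while the payload fields mirror RerankRequest.

import requests

resp = requests.post(
    "http://localhost:8000/v1/rerank",  # assumed route and host
    json={
        "model": "BAAI/bge-reranker-base",  # hypothetical reranker model
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "The Eiffel Tower is in Paris.",
            "Berlin is the capital of Germany.",
        ],
        "top_n": 2,  # keep only the two highest-scoring documents
    },
)
print(resp.json())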

View File

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Final, Optional, Union
import jinja2
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse,
TokenizerInfoResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_tokenize(
self,
request: TokenizeRequest,
raw_request: Request,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{self._base_request_id(raw_request)}"
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
tool_dicts = (None if request.tools is None else
[tool.model_dump() for tool in request.tools])
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else:
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
config=self._build_render_config(request),
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = []
for engine_prompt in engine_prompts:
self._log_inputs(request_id,
engine_prompt,
params=None,
lora_request=lora_request)
if isinstance(engine_prompt,
dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
token_strs = None
if request.return_token_strs:
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
return TokenizeResponse(tokens=input_ids,
token_strs=token_strs,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self,
request: DetokenizeRequest,
raw_request: Request,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{self._base_request_id(raw_request)}"
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request)
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)
async def get_tokenizer_info(
self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
"""Get comprehensive tokenizer information."""
try:
tokenizer = await self.engine_client.get_tokenizer()
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
return TokenizerInfoResponse(**info)
except Exception as e:
return self.create_error_response(
f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
@dataclass
class TokenizerInfo:
tokenizer: AnyTokenizer
chat_template: Optional[str]
def to_dict(self) -> dict[str, Any]:
"""Return the tokenizer configuration."""
return self._get_tokenizer_config()
def _get_tokenizer_config(self) -> dict[str, Any]:
"""Get tokenizer configuration directly from the tokenizer object."""
config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
# Remove file path fields
config.pop("vocab_file", None)
config.pop("merges_file", None)
config = self._make_json_serializable(config)
config["tokenizer_class"] = type(self.tokenizer).__name__
if self.chat_template:
config["chat_template"] = self.chat_template
return config
def _make_json_serializable(self, obj):
"""Convert any non-JSON-serializable objects to serializable format."""
if hasattr(obj, "content"):
return obj.content
elif isinstance(obj, dict):
return {k: self._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._make_json_serializable(item) for item in obj]
else:
return obj
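create_tokenize and create_detokenize expose the model tokenizer behind simple request/response types. A minimal round-trip sketch; the host and the /tokenize and /detokenize routes are assumptions, while the payload and response fields follow the protocol models used above.

import requests

BASE_URL = "http://localhost:8000"  # assumed server address

# Tokenize a prompt; the response carries tokens, count and max_model_len.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={"prompt": "Hello world",
                          "add_special_tokens": True}).json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

# Feed the token ids back through detokenize to recover the prompt text.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={"tokens": tok["tokens"]}).json()
print(detok["prompt"])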

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from typing import Optional, Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponseStreamChoice,
TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
TranslationResponseStreamChoice, TranslationStreamResponse)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
logger = init_logger(__name__)
class OpenAIServingTranscription(OpenAISpeechToText):
"""Handles transcription requests."""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe",
log_error_stack=log_error_stack)
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest,
raw_request: Request
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
ErrorResponse]:
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranscriptionResponse,
stream_generator_method=self.transcription_stream_generator,
)
async def transcription_stream_generator(
self, request: TranscriptionRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="transcription.chunk",
response_stream_choice_class=TranscriptionResponseStreamChoice,
stream_response_class=TranscriptionStreamResponse,
)
async for chunk in generator:
yield chunk
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
log_error_stack=log_error_stack)
async def create_translation(
self, audio_data: bytes, request: TranslationRequest,
raw_request: Request
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranslationResponse,
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self, request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
)
async for chunk in generator:
yield chunk
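Both classes delegate to the shared flow in OpenAISpeechToText. A minimal client sketch using the OpenAI Python SDK against a locally running server; the base URL, API key, model name, and audio file are assumptions, and response_format is limited to 'text' or 'json' by the handler.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed

with open("sample.wav", "rb") as f:  # hypothetical audio clip
    result = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",  # assumed transcription-capable model
        file=f,
        response_format="json",
    )
print(result.text)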

View File

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Callable, Literal, Optional, TypeVar, Union, cast
import numpy as np
from fastapi import Request
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
DeltaMessage, ErrorResponse, RequestResponseMetadata,
TranscriptionResponse, TranscriptionResponseStreamChoice,
TranscriptionStreamResponse, TranslationResponse,
TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
SpeechToTextRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput
from vllm.utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse]
T = TypeVar("T", bound=SpeechToTextResponse)
logger = init_logger(__name__)
class OpenAISpeechToText(OpenAIServing):
"""Base class for speech-to-text operations like transcription and
translation."""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack)
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
model_config, task_type)
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = self.model_cls.validate_language(request.to_language) \
if request.to_language else None
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
prompts.append(prompt)
return prompts, duration
async def _create_speech_to_text(
self,
audio_data: bytes,
request: SpeechToTextRequest,
raw_request: Request,
response_class: type[T],
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
) -> Union[T, AsyncGenerator[str, None], ErrorResponse]:
"""Base method for speech-to-text operations like transcription and
translation."""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
if request.response_format not in ['text', 'json']:
return self.create_error_response(
"Currently only support response_format `text` or `json`")
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
if lora_request:
return self.create_error_response(
"Currently do not support LoRA for "
f"{self.task_type.title()}.")
prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
None]]] = None
try:
# Unlike most decoder-only models, Whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel spectrogram.
default_max_tokens = self.model_config.max_model_len
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=None)
list_result_generator = [
self.engine_client.generate(
prompt,
sampling_params,
request_id,
) for prompt in prompts
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if request.stream:
return stream_generator_method(request, list_result_generator,
request_id, request_metadata,
duration_s)
# Non-streaming response.
try:
assert list_result_generator is not None
text = ""
for result_generator in list_result_generator:
async for op in result_generator:
text += op.outputs[0].text
if self.task_type == "transcribe":
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per OpenAI specs
"seconds": int(math.ceil(duration_s)),
}
final_response = cast(T, response_class(text=text,
usage=usage))
else:
# no usage in response for translation task
final_response = cast(
T, response_class(text=text)) # type: ignore[call-arg]
return final_response
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def _speech_to_text_stream_generator(
self,
request: SpeechToTextRequest,
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
response_stream_choice_class: Union[
type[TranscriptionResponseStreamChoice],
type[TranslationResponseStreamChoice]],
stream_response_class: Union[type[TranscriptionStreamResponse],
type[TranslationStreamResponse]],
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
completion_tokens = 0
num_prompt_tokens = 0
include_usage = request.stream_include_usage \
if request.stream_include_usage else False
include_continuous_usage = request.stream_continuous_usage_stats\
if include_usage and request.stream_continuous_usage_stats\
else False
try:
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config,
self.model_config):
num_prompt_tokens += audio_tokens
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...except).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = response_stream_choice_class(
delta=delta_message)
else:
# Model is finished generating.
choice_data = response_stream_choice_class(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = stream_response_class(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
if include_usage:
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens)
final_usage_chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# Report aggregate usage across all choices to the FastAPI middleware
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
def _split_audio(self, audio_data: np.ndarray,
sample_rate: int) -> list[np.ndarray]:
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start,
search_end)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def _find_split_point(self, wav: np.ndarray, start_idx: int,
end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
min_energy_window = self.asr_config.min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i:i + min_energy_window]
energy = (window**2).mean()**0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx
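_find_split_point scans the overlap region in fixed-size windows and picks the one with the lowest RMS energy, so long clips are cut at quiet spots. A self-contained numpy sketch of the same idea; the window size and the synthetic signal are made up for illustration.

import numpy as np


def find_split_point(wav: np.ndarray, start: int, end: int, window: int) -> int:
    # Return the start index of the quietest window inside [start, end).
    segment = wav[start:end]
    quietest, min_energy = start, float("inf")
    for i in range(0, len(segment) - window, window):
        energy = float(np.sqrt(np.mean(segment[i:i + window] ** 2)))
        if energy < min_energy:
            quietest, min_energy = start + i, energy
    return quietest


rng = np.random.default_rng(0)
signal = rng.normal(scale=1.0, size=16_000)
signal[6_000:6_400] *= 0.01  # a quiet stretch the search should prefer
print(find_split_point(signal, 4_000, 8_000, window=400))  # -> 6000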

View File

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .kimi_k2_tool_parser import KimiK2ToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .longcat_tool_parser import LongcatFlashToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .qwen3xml_tool_parser import Qwen3XMLToolParser
from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
"ToolParser",
"ToolParserManager",
"Granite20bFCToolParser",
"GraniteToolParser",
"Hermes2ProToolParser",
"MistralToolParser",
"Internlm2ToolParser",
"Llama3JsonToolParser",
"JambaToolParser",
"Llama4PythonicToolParser",
"LongcatFlashToolParser",
"PythonicToolParser",
"Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser",
"DeepSeekV31ToolParser",
"xLAMToolParser",
"MinimaxToolParser",
"KimiK2ToolParser",
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
"Qwen3XMLToolParser",
"SeedOssToolParser",
"Step3ToolParser",
"OpenAIToolParser",
]

Some files were not shown because too many files have changed in this diff