first commit
This commit is contained in:
0
vllm/entrypoints/__init__.py
Normal file
0
vllm/entrypoints/__init__.py
Normal file
BIN
vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/constants.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/constants.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/context.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/context.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/harmony_utils.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/harmony_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/launcher.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/launcher.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/llm.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/llm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/logger.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/renderer.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/renderer.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/score_utils.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/score_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/ssl.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/ssl.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/tool.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/tool.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/tool_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/tool_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
178
vllm/entrypoints/api_server.py
Normal file
178
vllm/entrypoints/api_server.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
|
||||
The request should be a JSON object with the following fields:
|
||||
- prompt: the prompt to use for the generation.
|
||||
- stream: whether to stream the results or not.
|
||||
- other fields: the sampling parameters (See `SamplingParams` for details).
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
return await _generate(request_dict, raw_request=request)
|
||||
|
||||
|
||||
@with_cancellation
|
||||
async def _generate(request_dict: dict, raw_request: Request) -> Response:
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
assert engine is not None
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
async def stream_results() -> AsyncGenerator[bytes, None]:
|
||||
async for request_output in results_generator:
|
||||
prompt = request_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [
|
||||
prompt + output.text for output in request_output.outputs
|
||||
]
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\n").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
try:
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
except asyncio.CancelledError:
|
||||
return Response(status_code=499)
|
||||
|
||||
assert final_output is not None
|
||||
prompt = final_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||
ret = {"text": text_outputs}
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
global app
|
||||
|
||||
app.root_path = args.root_path
|
||||
return app
|
||||
|
||||
|
||||
async def init_app(
|
||||
args: Namespace,
|
||||
llm_engine: Optional[AsyncLLMEngine] = None,
|
||||
) -> FastAPI:
|
||||
app = build_app(args)
|
||||
|
||||
global engine
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = (llm_engine
|
||||
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.API_SERVER))
|
||||
app.state.engine_client = engine
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(args: Namespace,
|
||||
llm_engine: Optional[AsyncLLMEngine] = None,
|
||||
**uvicorn_kwargs: Any) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
set_ulimit()
|
||||
|
||||
app = await init_app(args, llm_engine)
|
||||
assert engine is not None
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=None,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=parser.check_port, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-certfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The CA certificates file")
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy")
|
||||
parser.add_argument("--log-level", type=str, default="debug")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(run_server(args))
|
||||
1705
vllm/entrypoints/chat_utils.py
Normal file
1705
vllm/entrypoints/chat_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
12
vllm/entrypoints/cli/__init__.py
Normal file
12
vllm/entrypoints/cli/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
|
||||
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.throughput import (
|
||||
BenchmarkThroughputSubcommand)
|
||||
|
||||
__all__: list[str] = [
|
||||
"BenchmarkLatencySubcommand",
|
||||
"BenchmarkServingSubcommand",
|
||||
"BenchmarkThroughputSubcommand",
|
||||
]
|
||||
BIN
vllm/entrypoints/cli/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/collect_env.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/collect_env.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/main.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/main.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/openai.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/openai.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/run_batch.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/run_batch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/serve.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/serve.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/__pycache__/types.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/__pycache__/types.cpython-310.pyc
Normal file
Binary file not shown.
0
vllm/entrypoints/cli/benchmark/__init__.py
Normal file
0
vllm/entrypoints/cli/benchmark/__init__.py
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/benchmark/__pycache__/base.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/benchmark/__pycache__/base.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/entrypoints/cli/benchmark/__pycache__/main.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/benchmark/__pycache__/main.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/cli/benchmark/__pycache__/serve.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/cli/benchmark/__pycache__/serve.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
25
vllm/entrypoints/cli/benchmark/base.py
Normal file
25
vllm/entrypoints/cli/benchmark/base.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class BenchmarkSubcommandBase(CLISubcommand):
|
||||
""" The base class of subcommands for vllm bench. """
|
||||
|
||||
help: str
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
"""Add the CLI arguments to the parser."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
args: The arguments to the command.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
21
vllm/entrypoints/cli/benchmark/latency.py
Normal file
21
vllm/entrypoints/cli/benchmark/latency.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.latency import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
|
||||
""" The `latency` subcommand for vllm bench. """
|
||||
|
||||
name = "latency"
|
||||
help = "Benchmark the latency of a single batch of requests."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
55
vllm/entrypoints/cli/benchmark/main.py
Normal file
55
vllm/entrypoints/cli/benchmark/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class BenchmarkSubcommand(CLISubcommand):
|
||||
""" The `bench` subcommand for the vLLM CLI. """
|
||||
|
||||
name = "bench"
|
||||
help = "vLLM bench subcommand."
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
args.dispatch_function(args)
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
pass
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
bench_parser = subparsers.add_parser(
|
||||
self.name,
|
||||
description=self.help,
|
||||
usage=f"vllm {self.name} <bench_type> [options]")
|
||||
bench_subparsers = bench_parser.add_subparsers(required=True,
|
||||
dest="bench_type")
|
||||
|
||||
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
|
||||
cmd_subparser = bench_subparsers.add_parser(
|
||||
cmd_cls.name,
|
||||
help=cmd_cls.help,
|
||||
description=cmd_cls.help,
|
||||
usage=f"vllm {self.name} {cmd_cls.name} [options]",
|
||||
)
|
||||
cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd)
|
||||
cmd_cls.add_cli_args(cmd_subparser)
|
||||
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
|
||||
subcmd=f"{self.name} {cmd_cls.name}")
|
||||
return bench_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkSubcommand()]
|
||||
21
vllm/entrypoints/cli/benchmark/serve.py
Normal file
21
vllm/entrypoints/cli/benchmark/serve.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.serve import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
|
||||
""" The `serve` subcommand for vllm bench. """
|
||||
|
||||
name = "serve"
|
||||
help = "Benchmark the online serving throughput."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
21
vllm/entrypoints/cli/benchmark/throughput.py
Normal file
21
vllm/entrypoints/cli/benchmark/throughput.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.throughput import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
|
||||
""" The `throughput` subcommand for vllm bench. """
|
||||
|
||||
name = "throughput"
|
||||
help = "Benchmark offline inference throughput."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
36
vllm/entrypoints/cli/collect_env.py
Normal file
36
vllm/entrypoints/cli/collect_env.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
from vllm.collect_env import main as collect_env_main
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class CollectEnvSubcommand(CLISubcommand):
|
||||
"""The `collect-env` subcommand for the vLLM CLI. """
|
||||
name = "collect-env"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Collect information about the environment."""
|
||||
collect_env_main()
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
return subparsers.add_parser(
|
||||
"collect-env",
|
||||
help="Start collecting environment information.",
|
||||
description="Start collecting environment information.",
|
||||
usage="vllm collect-env")
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [CollectEnvSubcommand()]
|
||||
60
vllm/entrypoints/cli/main.py
Normal file
60
vllm/entrypoints/cli/main.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
'''The CLI entrypoints of vLLM
|
||||
|
||||
Note that all future modules must be lazily loaded within main
|
||||
to avoid certain eager import breakage.'''
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
|
||||
def main():
|
||||
import vllm.entrypoints.cli.benchmark.main
|
||||
import vllm.entrypoints.cli.collect_env
|
||||
import vllm.entrypoints.cli.openai
|
||||
import vllm.entrypoints.cli.run_batch
|
||||
import vllm.entrypoints.cli.serve
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
CMD_MODULES = [
|
||||
vllm.entrypoints.cli.openai,
|
||||
vllm.entrypoints.cli.serve,
|
||||
vllm.entrypoints.cli.benchmark.main,
|
||||
vllm.entrypoints.cli.collect_env,
|
||||
vllm.entrypoints.cli.run_batch,
|
||||
]
|
||||
|
||||
cli_env_setup()
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM CLI",
|
||||
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
|
||||
)
|
||||
parser.add_argument(
|
||||
'-v',
|
||||
'--version',
|
||||
action='version',
|
||||
version=importlib.metadata.version('vllm'),
|
||||
)
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
cmds = {}
|
||||
for cmd_module in CMD_MODULES:
|
||||
new_cmds = cmd_module.cmd_init()
|
||||
for cmd in new_cmds:
|
||||
cmd.subparser_init(subparsers).set_defaults(
|
||||
dispatch_function=cmd.cmd)
|
||||
cmds[cmd.name] = cmd
|
||||
args = parser.parse_args()
|
||||
if args.subparser in cmds:
|
||||
cmds[args.subparser].validate(args)
|
||||
|
||||
if hasattr(args, "dispatch_function"):
|
||||
args.dispatch_function(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
233
vllm/entrypoints/cli/openai.py
Normal file
233
vllm/entrypoints/cli/openai.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def _register_signal_handlers():
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTSTP, signal_handler)
|
||||
|
||||
|
||||
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
|
||||
_register_signal_handlers()
|
||||
|
||||
base_url = args.url
|
||||
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
|
||||
openai_client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
if args.model_name:
|
||||
model_name = args.model_name
|
||||
else:
|
||||
available_models = openai_client.models.list()
|
||||
model_name = available_models.data[0].id
|
||||
|
||||
print(f"Using model: {model_name}")
|
||||
|
||||
return model_name, openai_client
|
||||
|
||||
|
||||
def _print_chat_stream(stream) -> str:
|
||||
output = ""
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
output += delta.content
|
||||
print(delta.content, end="", flush=True)
|
||||
print()
|
||||
return output
|
||||
|
||||
|
||||
def _print_completion_stream(stream) -> str:
|
||||
output = ""
|
||||
for chunk in stream:
|
||||
text = chunk.choices[0].text
|
||||
if text is not None:
|
||||
output += text
|
||||
print(text, end="", flush=True)
|
||||
print()
|
||||
return output
|
||||
|
||||
|
||||
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
stream = client.chat.completions.create(model=model_name,
|
||||
messages=conversation,
|
||||
stream=True)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
|
||||
|
||||
def _add_query_options(
|
||||
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="url of the running OpenAI-Compatible RESTful API server")
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The model name used in prompt completion, default to "
|
||||
"the first model in list models API call."))
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"API key for OpenAI services. If provided, this api key "
|
||||
"will overwrite the api key obtained through environment variables."
|
||||
))
|
||||
return parser
|
||||
|
||||
|
||||
class ChatCommand(CLISubcommand):
|
||||
"""The `chat` subcommand for the vLLM CLI. """
|
||||
name = "chat"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
system_prompt = args.system_prompt
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
if args.quick:
|
||||
conversation.append({"role": "user", "content": args.quick})
|
||||
|
||||
stream = client.chat.completions.create(model=model_name,
|
||||
messages=conversation,
|
||||
stream=True)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
return
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
stream = client.chat.completions.create(model=model_name,
|
||||
messages=conversation,
|
||||
stream=True)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Add CLI arguments for the chat command."""
|
||||
_add_query_options(parser)
|
||||
parser.add_argument(
|
||||
"--system-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The system prompt to be added to the chat template, "
|
||||
"used for models that support system prompts."))
|
||||
parser.add_argument("-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="MESSAGE",
|
||||
help=("Send a single prompt as MESSAGE "
|
||||
"and print the response, then exit."))
|
||||
return parser
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
parser = subparsers.add_parser(
|
||||
"chat",
|
||||
help="Generate chat completions via the running API server.",
|
||||
description="Generate chat completions via the running API server.",
|
||||
usage="vllm chat [options]")
|
||||
return ChatCommand.add_cli_args(parser)
|
||||
|
||||
|
||||
class CompleteCommand(CLISubcommand):
|
||||
"""The `complete` subcommand for the vLLM CLI. """
|
||||
name = 'complete'
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
|
||||
if args.quick:
|
||||
stream = client.completions.create(model=model_name,
|
||||
prompt=args.quick,
|
||||
stream=True)
|
||||
_print_completion_stream(stream)
|
||||
return
|
||||
|
||||
print("Please enter prompt to complete:")
|
||||
while True:
|
||||
try:
|
||||
input_prompt = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
stream = client.completions.create(model=model_name,
|
||||
prompt=input_prompt,
|
||||
stream=True)
|
||||
_print_completion_stream(stream)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Add CLI arguments for the complete command."""
|
||||
_add_query_options(parser)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="PROMPT",
|
||||
help=
|
||||
"Send a single prompt and print the completion output, then exit.")
|
||||
return parser
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
parser = subparsers.add_parser(
|
||||
"complete",
|
||||
help=("Generate text completions based on the given prompt "
|
||||
"via the running API server."),
|
||||
description=("Generate text completions based on the given prompt "
|
||||
"via the running API server."),
|
||||
usage="vllm complete [options]")
|
||||
return CompleteCommand.add_cli_args(parser)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ChatCommand(), CompleteCommand()]
|
||||
67
vllm/entrypoints/cli/run_batch.py
Normal file
67
vllm/entrypoints/cli/run_batch.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib.metadata
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RunBatchSubcommand(CLISubcommand):
|
||||
"""The `run-batch` subcommand for vLLM CLI."""
|
||||
name = "run-batch"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
from vllm.entrypoints.openai.run_batch import main as run_batch_main
|
||||
|
||||
logger.info("vLLM batch processing API version %s",
|
||||
importlib.metadata.version("vllm"))
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server.
|
||||
# LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(run_batch_main(args))
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
from vllm.entrypoints.openai.run_batch import make_arg_parser
|
||||
|
||||
run_batch_parser = subparsers.add_parser(
|
||||
self.name,
|
||||
help="Run batch prompts and write results to file.",
|
||||
description=(
|
||||
"Run batch prompts using vLLM's OpenAI-compatible API.\n"
|
||||
"Supports local or HTTP input/output files."),
|
||||
usage=
|
||||
"vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
|
||||
)
|
||||
run_batch_parser = make_arg_parser(run_batch_parser)
|
||||
run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
|
||||
subcmd=self.name)
|
||||
return run_batch_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [RunBatchSubcommand()]
|
||||
232
vllm/entrypoints/cli/serve.py
Normal file
232
vllm/entrypoints/cli/serve.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
from typing import Optional
|
||||
|
||||
import uvloop
|
||||
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
|
||||
setup_server)
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, decorate_logs, get_tcp_uri,
|
||||
set_process_title)
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import (APIServerProcessManager,
|
||||
wait_for_completion_or_failure)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
|
||||
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.
|
||||
|
||||
Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
|
||||
--help=ModelConfig, --help=Frontend)
|
||||
Use `--help=all` to show all available flags at once.
|
||||
"""
|
||||
|
||||
|
||||
class ServeSubcommand(CLISubcommand):
|
||||
"""The `serve` subcommand for the vLLM CLI. """
|
||||
name = "serve"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
# If model is specified in CLI (as positional arg), it takes precedence
|
||||
if hasattr(args, 'model_tag') and args.model_tag is not None:
|
||||
args.model = args.model_tag
|
||||
|
||||
if args.headless or args.api_server_count < 1:
|
||||
run_headless(args)
|
||||
else:
|
||||
if args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
serve_parser = subparsers.add_parser(
|
||||
self.name,
|
||||
description=DESCRIPTION,
|
||||
usage="vllm serve [model_tag] [options]")
|
||||
|
||||
serve_parser = make_arg_parser(serve_parser)
|
||||
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
|
||||
subcmd=self.name)
|
||||
return serve_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ServeSubcommand()]
|
||||
|
||||
|
||||
def run_headless(args: argparse.Namespace):
|
||||
|
||||
if args.api_server_count > 1:
|
||||
raise ValueError("api_server_count can't be set in headless mode")
|
||||
|
||||
# Create the EngineConfig.
|
||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
|
||||
headless=True)
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError("Headless mode is only supported for V1")
|
||||
|
||||
if engine_args.data_parallel_hybrid_lb:
|
||||
raise ValueError("data_parallel_hybrid_lb is not applicable in "
|
||||
"headless mode")
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in "
|
||||
"headless mode")
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
logger.debug("Received %d signal.", signum)
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
logger.info(
|
||||
"Launching %d data parallel engine(s) in headless mode, "
|
||||
"with head node address %s.", local_engine_count, handshake_address)
|
||||
|
||||
# Create the engines.
|
||||
engine_manager = CoreEngineProcManager(
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=vllm_config.parallel_config.data_parallel_rank,
|
||||
local_start_index=0,
|
||||
vllm_config=vllm_config,
|
||||
local_client=False,
|
||||
handshake_address=handshake_address,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
|
||||
try:
|
||||
engine_manager.join_first()
|
||||
finally:
|
||||
logger.info("Shutting down.")
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
assert not args.headless
|
||||
num_api_servers: int = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
listen_address, sock = setup_server(args)
|
||||
|
||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||
engine_args._api_process_count = num_api_servers
|
||||
engine_args._api_process_rank = -1
|
||||
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
|
||||
if num_api_servers > 1:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError("api_server_count > 1 is only supported for V1")
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
||||
"with api_server_count > 1")
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
log_stats = not engine_args.disable_log_stats
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb
|
||||
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
|
||||
|
||||
api_server_manager: Optional[APIServerProcessManager] = None
|
||||
|
||||
with launch_core_engines(vllm_config, executor_class, log_stats,
|
||||
num_api_servers) as (local_engine_manager,
|
||||
coordinator, addresses):
|
||||
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=addresses.inputs,
|
||||
output_addresses=addresses.outputs,
|
||||
stats_update_address=coordinator.get_stats_publish_address()
|
||||
if coordinator else None)
|
||||
|
||||
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
|
||||
# start of the API servers until the local engine is started
|
||||
# (after the launcher context manager exits),
|
||||
# since we get the front-end stats update address from the coordinator
|
||||
# via the handshake with the local engine.
|
||||
if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb):
|
||||
# Start API servers using the manager.
|
||||
api_server_manager = APIServerProcessManager(
|
||||
**api_server_manager_kwargs)
|
||||
|
||||
# Start API servers now if they weren't already started.
|
||||
if api_server_manager is None:
|
||||
api_server_manager_kwargs["stats_update_address"] = (
|
||||
addresses.frontend_stats_publish_address)
|
||||
api_server_manager = APIServerProcessManager(
|
||||
**api_server_manager_kwargs)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(listen_address,
|
||||
sock,
|
||||
args,
|
||||
client_config=None,
|
||||
**uvicorn_kwargs) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
client_config = client_config or {}
|
||||
server_index = client_config.get("client_index", 0)
|
||||
|
||||
# Set process title and add process-specific prefix to stdout and stderr.
|
||||
set_process_title("APIServer", str(server_index))
|
||||
decorate_logs()
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config,
|
||||
**uvicorn_kwargs))
|
||||
29
vllm/entrypoints/cli/types.py
Normal file
29
vllm/entrypoints/cli/types.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class CLISubcommand:
|
||||
"""Base class for CLI argument handlers."""
|
||||
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
# No validation by default
|
||||
pass
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
10
vllm/entrypoints/constants.py
Normal file
10
vllm/entrypoints/constants.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared constants for vLLM entrypoints.
|
||||
"""
|
||||
|
||||
# HTTP header limits for h11 parser
|
||||
# These constants help mitigate header abuse attacks
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
|
||||
H11_MAX_HEADER_COUNT_DEFAULT = 256
|
||||
481
vllm/entrypoints/context.py
Normal file
481
vllm/entrypoints/context.py
Normal file
@@ -0,0 +1,481 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
||||
from vllm.entrypoints.tool import Tool
|
||||
from vllm.entrypoints.tool_server import ToolServer
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This is currently needed as the tool type doesn't 1:1 match the
|
||||
# tool namespace, which is what is used to look up the
|
||||
# connection to the tool server
|
||||
_TOOL_NAME_TO_TYPE_MAP = {
|
||||
"browser": "web_search_preview",
|
||||
"python": "code_interpreter",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}")
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnTokens:
|
||||
"""Tracks token counts for a single conversation turn."""
|
||||
|
||||
def __init__(self, input_tokens=0, output_tokens=0):
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
|
||||
def reset(self):
|
||||
"""Reset counters for a new turn."""
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of this turn's token counts."""
|
||||
return TurnTokens(self.input_tokens, self.output_tokens)
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def append_output(self, output) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(self) -> list[Message]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_for_completion(self) -> list[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
# todo num_reasoning_tokens is not implemented yet.
|
||||
self.num_reasoning_tokens = 0
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
if not isinstance(output, RequestOutput):
|
||||
raise ValueError("SimpleContext only supports RequestOutput.")
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
return False
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: list,
|
||||
available_tools: list[str],
|
||||
):
|
||||
self._messages = messages
|
||||
self.finish_reason: Optional[str] = None
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
self.num_tool_output_tokens = 0
|
||||
|
||||
# Turn tracking - replaces multiple individual tracking variables
|
||||
self.current_turn = TurnTokens()
|
||||
self.previous_turn = TurnTokens()
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
|
||||
def _update_num_reasoning_tokens(self):
|
||||
# Count all analysis and commentary channels as reasoning tokens
|
||||
if self.parser.current_channel in {"analysis", "commentary"}:
|
||||
self.num_reasoning_tokens += 1
|
||||
|
||||
def append_output(self, output: Union[RequestOutput,
|
||||
list[Message]]) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
for token_id in output_token_ids:
|
||||
self.parser.process(token_id)
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self._update_prefill_token_usage(output)
|
||||
# Reset current turn output tokens for this turn
|
||||
self.current_turn.output_tokens = 0
|
||||
self._update_decode_token_usage(output)
|
||||
# Move current turn to previous turn for next turn's calculations
|
||||
self.previous_turn = self.current_turn.copy()
|
||||
# append_output is called only once before tool calling
|
||||
# in non-streaming case
|
||||
# so we can append all the parser messages to _messages
|
||||
output_msgs = self.parser.messages
|
||||
# The responses finish reason is set in the last message
|
||||
self.finish_reason = output.outputs[0].finish_reason
|
||||
else:
|
||||
# Tool output.
|
||||
output_msgs = output
|
||||
self._messages.extend(output_msgs)
|
||||
|
||||
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
|
||||
"""Update token usage statistics for the prefill phase of generation.
|
||||
|
||||
The prefill phase processes the input prompt tokens. This method:
|
||||
1. Counts the prompt tokens for this turn
|
||||
2. Calculates tool output tokens for multi-turn conversations
|
||||
3. Updates cached token counts
|
||||
4. Tracks state for next turn calculations
|
||||
|
||||
Tool output tokens are calculated as:
|
||||
current_prompt_tokens - last_turn_prompt_tokens -
|
||||
last_turn_output_tokens
|
||||
This represents tokens added between turns (typically tool responses).
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing prompt token information
|
||||
"""
|
||||
if output.prompt_token_ids is not None:
|
||||
this_turn_input_tokens = len(output.prompt_token_ids)
|
||||
else:
|
||||
this_turn_input_tokens = 0
|
||||
logger.error(
|
||||
"RequestOutput appended contains no prompt_token_ids.")
|
||||
|
||||
# Update current turn input tokens
|
||||
self.current_turn.input_tokens = this_turn_input_tokens
|
||||
self.num_prompt_tokens += this_turn_input_tokens
|
||||
|
||||
# Calculate tool tokens (except on first turn)
|
||||
if self.is_first_turn:
|
||||
self.is_first_turn = False
|
||||
else:
|
||||
# start counting tool after first turn
|
||||
# tool tokens = this turn prefill - last turn prefill -
|
||||
# last turn decode
|
||||
this_turn_tool_tokens = (self.current_turn.input_tokens -
|
||||
self.previous_turn.input_tokens -
|
||||
self.previous_turn.output_tokens)
|
||||
|
||||
# Handle negative tool token counts (shouldn't happen in normal
|
||||
# cases)
|
||||
if this_turn_tool_tokens < 0:
|
||||
logger.error(
|
||||
"Negative tool output tokens calculated: %d "
|
||||
"(current_input=%d, previous_input=%d, "
|
||||
"previous_output=%d). Setting to 0.",
|
||||
this_turn_tool_tokens, self.current_turn.input_tokens,
|
||||
self.previous_turn.input_tokens,
|
||||
self.previous_turn.output_tokens)
|
||||
this_turn_tool_tokens = 0
|
||||
|
||||
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||
|
||||
# Update cached tokens
|
||||
if output.num_cached_tokens is not None:
|
||||
self.num_cached_tokens += output.num_cached_tokens
|
||||
|
||||
def _update_decode_token_usage(self, output: RequestOutput) -> int:
|
||||
"""Update token usage statistics for the decode phase of generation.
|
||||
|
||||
The decode phase processes the generated output tokens. This method:
|
||||
1. Counts output tokens from all completion outputs
|
||||
2. Updates the total output token count
|
||||
3. Tracks tokens generated in the current turn
|
||||
|
||||
In streaming mode, this is called for each token generated.
|
||||
In non-streaming mode, this is called once with all output tokens.
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing generated token information
|
||||
|
||||
Returns:
|
||||
int: Number of output tokens processed in this call
|
||||
"""
|
||||
updated_output_token_count = 0
|
||||
if output.outputs:
|
||||
for completion_output in output.outputs:
|
||||
# only keep last round
|
||||
updated_output_token_count += len(completion_output.token_ids)
|
||||
self.num_output_tokens += updated_output_token_count
|
||||
self.current_turn.output_tokens += updated_output_token_count
|
||||
return updated_output_token_count
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (recipient.startswith("browser.")
|
||||
or recipient.startswith("python") or
|
||||
recipient.startswith("container."))
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
return []
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
if recipient is not None:
|
||||
if recipient.startswith("browser."):
|
||||
return await self.call_search_tool(
|
||||
self._tool_sessions["browser"], last_msg)
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
return render_for_completion(self.messages)
|
||||
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
"code": last_msg.content[0].text,
|
||||
}
|
||||
result = await tool_session.call_tool("python", param)
|
||||
result_str = result.content[0].text
|
||||
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name="python")
|
||||
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT)
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = mcp_tools[
|
||||
tool_type].headers if tool_type in mcp_tools else None
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id,
|
||||
headers))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info("Cleaning up tool session for %s",
|
||||
tool_session._client_info)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools))
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_output = None
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.encoding = get_encoding()
|
||||
self.last_tok = None
|
||||
self.first_tok_of_message = True
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def append_output(self, output: Union[RequestOutput,
|
||||
list[Message]]) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
if self.first_tok_of_message:
|
||||
self._update_prefill_token_usage(output)
|
||||
self.current_turn.output_tokens = 0
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
# beginning of a new message
|
||||
self.first_tok_of_message = output.finished
|
||||
for tok in output.outputs[0].token_ids:
|
||||
self.parser.process(tok)
|
||||
self._update_decode_token_usage(output)
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
self.previous_turn = self.current_turn.copy()
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
if len(self._messages) - self.num_init_messages < len(
|
||||
self.parser.messages):
|
||||
self._messages.extend(
|
||||
self.parser.messages[len(self._messages) -
|
||||
self.num_init_messages:])
|
||||
else:
|
||||
# Handle the case of tool output in direct message format
|
||||
assert len(output) == 1, "Tool output should be a single message"
|
||||
msg = output[0]
|
||||
# Sometimes the recipient is not set for tool messages,
|
||||
# so we set it to "assistant"
|
||||
if msg.author.role == Role.TOOL and msg.recipient is None:
|
||||
msg.recipient = "assistant"
|
||||
toks = self.encoding.render(msg)
|
||||
for tok in toks:
|
||||
self.parser.process(tok)
|
||||
self.last_tok = toks[-1]
|
||||
# TODO: add tool_output messages to self._messages
|
||||
|
||||
def is_expecting_start(self) -> bool:
|
||||
return self.parser.state == StreamState.EXPECT_START
|
||||
|
||||
def is_assistant_action_turn(self) -> bool:
|
||||
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions(
|
||||
)
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
# now this list of tokens as next turn's starting tokens
|
||||
# `<|start|>assistant``,
|
||||
# we need to process them in parser.
|
||||
rendered_tokens = super().render_for_completion()
|
||||
|
||||
last_n = -1
|
||||
to_process = []
|
||||
while rendered_tokens[last_n] != self.last_tok:
|
||||
to_process.append(rendered_tokens[last_n])
|
||||
last_n -= 1
|
||||
for tok in reversed(to_process):
|
||||
self.parser.process(tok)
|
||||
|
||||
return rendered_tokens
|
||||
436
vllm/entrypoints/harmony_utils.py
Normal file
436
vllm/entrypoints/harmony_utils.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from openai.types.responses import (ResponseFunctionToolCall,
|
||||
ResponseOutputItem, ResponseOutputMessage,
|
||||
ResponseOutputText, ResponseReasoningItem)
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent)
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (Author, ChannelConfig, Conversation,
|
||||
DeveloperContent, HarmonyEncodingName, Message,
|
||||
ReasoningEffort, Role, StreamableParser,
|
||||
SystemContent, TextContent, ToolDescription,
|
||||
load_harmony_encoding)
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||
ResponseInputOutputItem)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
"low": ReasoningEffort.LOW,
|
||||
}
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
# Builtin tools that should be included in the system message when
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
BUILTIN_TOOLS = {
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
}
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: list[str]) -> bool:
|
||||
return not set(tool_types).issubset(BUILTIN_TOOLS)
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _harmony_encoding
|
||||
if _harmony_encoding is None:
|
||||
_harmony_encoding = load_harmony_encoding(
|
||||
HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return _harmony_encoding
|
||||
|
||||
|
||||
def get_system_message(
|
||||
model_identity: Optional[str] = None,
|
||||
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
|
||||
start_date: Optional[str] = None,
|
||||
browser_description: Optional[str] = None,
|
||||
python_description: Optional[str] = None,
|
||||
container_description: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
with_custom_tools: bool = False,
|
||||
) -> Message:
|
||||
sys_msg_content = SystemContent.new()
|
||||
if model_identity is not None:
|
||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||
if (instructions is not None
|
||||
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
current_identity = sys_msg_content.model_identity
|
||||
new_identity = (f'{current_identity}\n{instructions}'
|
||||
if current_identity else instructions)
|
||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
||||
if reasoning_effort is not None:
|
||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||
REASONING_EFFORT[reasoning_effort])
|
||||
if start_date is None:
|
||||
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
|
||||
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
|
||||
if browser_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
if python_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||
if container_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
||||
if not with_custom_tools:
|
||||
channel_config = sys_msg_content.channel_config
|
||||
invalid_channel = "commentary"
|
||||
new_config = ChannelConfig.require_channels(
|
||||
[c for c in channel_config.valid_channels if c != invalid_channel])
|
||||
sys_msg_content = sys_msg_content.with_channel_config(new_config)
|
||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||
return sys_msg
|
||||
|
||||
|
||||
def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
|
||||
if isinstance(tool, ChatCompletionToolsParam):
|
||||
return ToolDescription.new(
|
||||
name=tool.function.name,
|
||||
description=tool.function.description,
|
||||
parameters=tool.function.parameters,
|
||||
)
|
||||
return ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
)
|
||||
|
||||
|
||||
def get_developer_message(
|
||||
instructions: Optional[str] = None,
|
||||
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if (instructions is not None
|
||||
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||
for tool in tools:
|
||||
if tool.type in ("web_search_preview", "code_interpreter",
|
||||
"container", "mcp"):
|
||||
# These are built-in tools that are added to the system message.
|
||||
# Adding in MCP for now until we support MCP tools executed
|
||||
# server side
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
function_tools.append(tool)
|
||||
else:
|
||||
raise ValueError(f"tool type {tool.type} not supported")
|
||||
if function_tools:
|
||||
function_tool_descriptions = [
|
||||
create_tool_definition(tool) for tool in function_tools
|
||||
]
|
||||
dev_msg_content = dev_msg_content.with_function_tools(
|
||||
function_tool_descriptions)
|
||||
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
|
||||
return dev_msg
|
||||
|
||||
|
||||
def get_user_message(content: str) -> Message:
|
||||
return Message.from_role_and_content(Role.USER, content)
|
||||
|
||||
|
||||
def parse_response_input(
|
||||
response_msg: ResponseInputOutputItem,
|
||||
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
|
||||
) -> Message:
|
||||
if not isinstance(response_msg, dict):
|
||||
response_msg = response_msg.model_dump()
|
||||
if "type" not in response_msg or response_msg["type"] == "message":
|
||||
role = response_msg["role"]
|
||||
content = response_msg["content"]
|
||||
if role == "system":
|
||||
# User is trying to set a system message. Change it to:
|
||||
# <|start|>developer<|message|># Instructions
|
||||
# {instructions}<|end|>
|
||||
role = "developer"
|
||||
text_prefix = "Instructions:\n"
|
||||
else:
|
||||
text_prefix = ""
|
||||
if isinstance(content, str):
|
||||
msg = Message.from_role_and_content(role, text_prefix + content)
|
||||
else:
|
||||
contents = [
|
||||
TextContent(text=text_prefix + c["text"]) for c in content
|
||||
]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: Optional[ResponseFunctionToolCall] = None
|
||||
for prev_response in reversed(prev_responses):
|
||||
if isinstance(prev_response, ResponseFunctionToolCall
|
||||
) and prev_response.call_id == call_id:
|
||||
call_response = prev_response
|
||||
break
|
||||
if call_response is None:
|
||||
raise ValueError(f"No call message found for {call_id}")
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{call_response.name}"),
|
||||
response_msg["output"])
|
||||
elif response_msg["type"] == "reasoning":
|
||||
content = response_msg["content"]
|
||||
assert len(content) == 1
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
|
||||
elif response_msg["type"] == "function_call":
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT,
|
||||
response_msg["arguments"])
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{response_msg['name']}")
|
||||
msg = msg.with_content_type("json")
|
||||
else:
|
||||
raise ValueError(f"Unknown input type: {response_msg['type']}")
|
||||
return msg
|
||||
|
||||
|
||||
def parse_chat_input(chat_msg) -> list[Message]:
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
role = chat_msg.get("role")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
if role == "assistant" and tool_calls:
|
||||
msgs: list[Message] = []
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"),
|
||||
content).with_channel("commentary")
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
return [msg]
|
||||
|
||||
|
||||
def render_for_completion(messages: list[Message]) -> list[int]:
|
||||
conversation = Conversation.from_messages(messages)
|
||||
token_ids = get_encoding().render_conversation_for_completion(
|
||||
conversation, Role.ASSISTANT)
|
||||
return token_ids
|
||||
|
||||
|
||||
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
||||
"""
|
||||
Parse a Harmony message into a list of output response items.
|
||||
"""
|
||||
if message.author.role != "assistant":
|
||||
# This is a message from a tool to the assistant (e.g., search result).
|
||||
# Don't include it in the final output for now. This aligns with
|
||||
# OpenAI's behavior on models like o4-mini.
|
||||
return []
|
||||
|
||||
output_items: list[ResponseOutputItem] = []
|
||||
recipient = message.recipient
|
||||
if recipient is not None and recipient.startswith("browser."):
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
browser_call = json.loads(content.text)
|
||||
# TODO: translate to url properly!
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search")
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page")
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(pattern=browser_call["pattern"],
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find")
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
web_search_item = ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
output_items.append(web_search_item)
|
||||
elif message.channel == "analysis":
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=content.text,
|
||||
type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
elif message.channel == "commentary":
|
||||
if recipient is not None and recipient.startswith("functions."):
|
||||
function_name = recipient.split(".")[-1]
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
elif recipient is not None and (recipient.startswith("python")
|
||||
or recipient.startswith("browser")
|
||||
or recipient.startswith("container")):
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=content.text,
|
||||
type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
else:
|
||||
raise ValueError(f"Unknown recipient: {recipient}")
|
||||
elif message.channel == "final":
|
||||
contents = []
|
||||
for content in message.content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content.text,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
contents.append(output_text)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=contents,
|
||||
role=message.author.role,
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
output_items.append(text_item)
|
||||
else:
|
||||
raise ValueError(f"Unknown channel: {message.channel}")
|
||||
return output_items
|
||||
|
||||
|
||||
def parse_remaining_state(
|
||||
parser: StreamableParser) -> list[ResponseOutputItem]:
|
||||
if not parser.current_content:
|
||||
return []
|
||||
if parser.current_role != Role.ASSISTANT:
|
||||
return []
|
||||
current_recipient = parser.current_recipient
|
||||
if (current_recipient is not None
|
||||
and current_recipient.startswith("browser.")):
|
||||
return []
|
||||
|
||||
if parser.current_channel == "analysis":
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=parser.current_content,
|
||||
type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
return [reasoning_item]
|
||||
elif parser.current_channel == "final":
|
||||
output_text = ResponseOutputText(
|
||||
text=parser.current_content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=[output_text],
|
||||
role="assistant",
|
||||
# if the parser still has messages (ie if the generator got cut
|
||||
# abruptly), this should be incomplete
|
||||
status="incomplete",
|
||||
type="message",
|
||||
)
|
||||
return [text_item]
|
||||
return []
|
||||
|
||||
|
||||
def get_stop_tokens_for_assistant_actions() -> list[int]:
|
||||
return get_encoding().stop_tokens_for_assistant_actions()
|
||||
|
||||
|
||||
def get_streamable_parser_for_assistant() -> StreamableParser:
|
||||
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
|
||||
|
||||
|
||||
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
|
||||
parser = get_streamable_parser_for_assistant()
|
||||
for token_id in token_ids:
|
||||
parser.process(token_id)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_chat_output(
|
||||
token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]:
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
output_msgs = parser.messages
|
||||
is_tool_call = False # TODO: update this when tool call is supported
|
||||
if len(output_msgs) == 0:
|
||||
# The generation has stopped during reasoning.
|
||||
reasoning_content = parser.current_content
|
||||
final_content = None
|
||||
elif len(output_msgs) == 1:
|
||||
# The generation has stopped during final message.
|
||||
reasoning_content = output_msgs[0].content[0].text
|
||||
final_content = parser.current_content
|
||||
else:
|
||||
reasoning_msg = output_msgs[:-1]
|
||||
final_msg = output_msgs[-1]
|
||||
reasoning_content = "\n".join(
|
||||
[msg.content[0].text for msg in reasoning_msg])
|
||||
final_content = final_msg.content[0].text
|
||||
return reasoning_content, final_content, is_tool_call
|
||||
164
vllm/entrypoints/launcher.py
Normal file
164
vllm/entrypoints/launcher.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import socket
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(app: FastAPI,
|
||||
sock: Optional[socket.socket],
|
||||
enable_ssl_refresh: bool = False,
|
||||
**uvicorn_kwargs: Any):
|
||||
"""
|
||||
Start a FastAPI app using Uvicorn, with support for custom Uvicorn config
|
||||
options. Supports http header limits via h11_max_incomplete_event_size and
|
||||
h11_max_header_count.
|
||||
"""
|
||||
logger.info("Available routes are:")
|
||||
for route in app.routes:
|
||||
methods = getattr(route, "methods", None)
|
||||
path = getattr(route, "path", None)
|
||||
|
||||
if methods is None or path is None:
|
||||
continue
|
||||
|
||||
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
||||
|
||||
# Extract header limit options if present
|
||||
h11_max_incomplete_event_size = uvicorn_kwargs.pop(
|
||||
"h11_max_incomplete_event_size", None)
|
||||
h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None)
|
||||
|
||||
# Set safe defaults if not provided
|
||||
if h11_max_incomplete_event_size is None:
|
||||
h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
if h11_max_header_count is None:
|
||||
h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
|
||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||
# Set header limits
|
||||
config.h11_max_incomplete_event_size = h11_max_incomplete_event_size
|
||||
config.h11_max_header_count = h11_max_header_count
|
||||
config.load()
|
||||
server = uvicorn.Server(config)
|
||||
_add_shutdown_handlers(app, server)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
watchdog_task = loop.create_task(
|
||||
watchdog_loop(server, app.state.engine_client))
|
||||
server_task = loop.create_task(
|
||||
server.serve(sockets=[sock] if sock else None))
|
||||
|
||||
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
|
||||
ssl_context=config.ssl,
|
||||
key_path=config.ssl_keyfile,
|
||||
cert_path=config.ssl_certfile,
|
||||
ca_path=config.ssl_ca_certs)
|
||||
|
||||
def signal_handler() -> None:
|
||||
# prevents the uvicorn signal handler to exit early
|
||||
server_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
|
||||
async def dummy_shutdown() -> None:
|
||||
pass
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
await server_task
|
||||
return dummy_shutdown()
|
||||
except asyncio.CancelledError:
|
||||
port = uvicorn_kwargs["port"]
|
||||
process = find_process_using_port(port)
|
||||
if process is not None:
|
||||
logger.warning(
|
||||
"port %s is used by process %s launched with command:\n%s",
|
||||
port, process, " ".join(process.cmdline()))
|
||||
logger.info("Shutting down FastAPI HTTP server.")
|
||||
return server.shutdown()
|
||||
finally:
|
||||
watchdog_task.cancel()
|
||||
|
||||
|
||||
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
# Watchdog task that runs in the background, checking
|
||||
# for error state in the engine. Needed to trigger shutdown
|
||||
# if an exception arises is StreamingResponse() generator.
|
||||
"""
|
||||
VLLM_WATCHDOG_TIME_S = 5.0
|
||||
while True:
|
||||
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
|
||||
terminate_if_errored(server, engine)
|
||||
|
||||
|
||||
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
See discussions here on shutting down a uvicorn server
|
||||
https://github.com/encode/uvicorn/discussions/1103
|
||||
In this case we cannot await the server shutdown here
|
||||
because handler must first return to close the connection
|
||||
for this request.
|
||||
"""
|
||||
engine_errored = engine.errored and not engine.is_running
|
||||
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
|
||||
server.should_exit = True
|
||||
|
||||
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
||||
"""
|
||||
VLLM V1 AsyncLLM catches exceptions and returns
|
||||
only two types: EngineGenerateError and EngineDeadError.
|
||||
|
||||
EngineGenerateError is raised by the per request generate()
|
||||
method. This error could be request specific (and therefore
|
||||
recoverable - e.g. if there is an error in input processing).
|
||||
|
||||
EngineDeadError is raised by the background output_handler
|
||||
method. This error is global and therefore not recoverable.
|
||||
|
||||
We register these @app.exception_handlers to return nice
|
||||
responses to the end user if they occur and shut down if needed.
|
||||
See https://fastapi.tiangolo.com/tutorial/handling-errors/
|
||||
for more details on how exception handlers work.
|
||||
|
||||
If an exception is encountered in a StreamingResponse
|
||||
generator, the exception is not raised, since we already sent
|
||||
a 200 status. Rather, we send an error message as the next chunk.
|
||||
Since the exception is not raised, this means that the server
|
||||
will not automatically shut down. Instead, we use the watchdog
|
||||
background task for check for errored state.
|
||||
"""
|
||||
|
||||
@app.exception_handler(RuntimeError)
|
||||
@app.exception_handler(EngineDeadError)
|
||||
@app.exception_handler(EngineGenerateError)
|
||||
async def runtime_exception_handler(request: Request, __):
|
||||
terminate_if_errored(
|
||||
server=server,
|
||||
engine=request.app.state.engine_client,
|
||||
)
|
||||
|
||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
1629
vllm/entrypoints/llm.py
Normal file
1629
vllm/entrypoints/llm.py
Normal file
File diff suppressed because it is too large
Load Diff
79
vllm/entrypoints/logger.py
Normal file
79
vllm/entrypoints/logger.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RequestLogger:
|
||||
|
||||
def __init__(self, *, max_log_len: Optional[int]) -> None:
|
||||
self.max_log_len = max_log_len
|
||||
|
||||
def log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_embeds: Optional[torch.Tensor],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if prompt is not None:
|
||||
prompt = prompt[:max_log_len]
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
prompt_token_ids = prompt_token_ids[:max_log_len]
|
||||
|
||||
logger.info(
|
||||
"Received request %s: prompt: %r, "
|
||||
"params: %s, prompt_token_ids: %s, "
|
||||
"prompt_embeds shape: %s, "
|
||||
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
|
||||
prompt_embeds.shape if prompt_embeds is not None else None,
|
||||
lora_request)
|
||||
|
||||
def log_outputs(
|
||||
self,
|
||||
request_id: str,
|
||||
outputs: str,
|
||||
output_token_ids: Optional[Sequence[int]],
|
||||
finish_reason: Optional[str] = None,
|
||||
is_streaming: bool = False,
|
||||
delta: bool = False,
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if outputs is not None:
|
||||
outputs = outputs[:max_log_len]
|
||||
|
||||
if output_token_ids is not None:
|
||||
# Convert to list and apply truncation
|
||||
output_token_ids = list(output_token_ids)[:max_log_len]
|
||||
|
||||
stream_info = ""
|
||||
if is_streaming:
|
||||
stream_info = (" (streaming delta)"
|
||||
if delta else " (streaming complete)")
|
||||
|
||||
logger.info(
|
||||
"Generated response %s%s: output: %r, "
|
||||
"output_token_ids: %s, finish_reason: %s",
|
||||
request_id,
|
||||
stream_info,
|
||||
outputs,
|
||||
output_token_ids,
|
||||
finish_reason,
|
||||
)
|
||||
0
vllm/entrypoints/openai/__init__.py
Normal file
0
vllm/entrypoints/openai/__init__.py
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1953
vllm/entrypoints/openai/api_server.py
Normal file
1953
vllm/entrypoints/openai/api_server.py
Normal file
File diff suppressed because it is too large
Load Diff
288
vllm/entrypoints/openai/cli_args.py
Normal file
288
vllm/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: list[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, ""]: # Skip if item is None or empty string
|
||||
continue
|
||||
if "=" in item and "," not in item: # Old format: name=path
|
||||
name, path = item.split("=")
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FrontendArgs:
|
||||
"""Arguments for the OpenAI-compatible frontend server."""
|
||||
host: Optional[str] = None
|
||||
"""Host name."""
|
||||
port: int = 8000
|
||||
"""Port number."""
|
||||
uds: Optional[str] = None
|
||||
"""Unix domain socket path. If set, host and port arguments are ignored."""
|
||||
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
|
||||
"trace"] = "info"
|
||||
"""Log level for uvicorn."""
|
||||
disable_uvicorn_access_log: bool = False
|
||||
"""Disable uvicorn access log."""
|
||||
allow_credentials: bool = False
|
||||
"""Allow credentials."""
|
||||
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed origins."""
|
||||
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed methods."""
|
||||
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed headers."""
|
||||
api_key: Optional[list[str]] = None
|
||||
"""If provided, the server will require one of these keys to be presented in
|
||||
the header."""
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None
|
||||
"""LoRA modules configurations in either 'name=path' format or JSON format
|
||||
or JSON list format. Example (old format): `'name=path'` Example (new
|
||||
format): `{\"name\": \"name\", \"path\": \"lora_path\",
|
||||
\"base_model_name\": \"id\"}`"""
|
||||
chat_template: Optional[str] = None
|
||||
"""The file path to the chat template, or the template in single-line form
|
||||
for the specified model."""
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: Optional[str] = None
|
||||
"""The file path to the SSL key file."""
|
||||
ssl_certfile: Optional[str] = None
|
||||
"""The file path to the SSL cert file."""
|
||||
ssl_ca_certs: Optional[str] = None
|
||||
"""The CA certificates file."""
|
||||
enable_ssl_refresh: bool = False
|
||||
"""Refresh SSL Context when SSL certificate files change"""
|
||||
ssl_cert_reqs: int = int(ssl.CERT_NONE)
|
||||
"""Whether client certificate is required (see stdlib ssl module's)."""
|
||||
root_path: Optional[str] = None
|
||||
"""FastAPI root_path when app is behind a path based routing proxy."""
|
||||
middleware: list[str] = field(default_factory=lambda: [])
|
||||
"""Additional ASGI middleware to apply to the app. We accept multiple
|
||||
--middleware arguments. The value should be an import path. If a function
|
||||
is provided, vLLM will add it to the server using
|
||||
`@app.middleware('http')`. If a class is provided, vLLM will
|
||||
add it to the server using `app.add_middleware()`."""
|
||||
return_tokens_as_token_ids: bool = False
|
||||
"""When `--max-logprobs` is specified, represents single tokens as
|
||||
strings of the form 'token_id:{token_id}' so that tokens that are not
|
||||
JSON-encodable can be identified."""
|
||||
disable_frontend_multiprocessing: bool = False
|
||||
"""If specified, will run the OpenAI frontend server in the same process as
|
||||
the model serving engine."""
|
||||
enable_request_id_headers: bool = False
|
||||
"""If specified, API server will add X-Request-Id header to responses."""
|
||||
enable_auto_tool_choice: bool = False
|
||||
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
|
||||
to specify which parser to use."""
|
||||
exclude_tools_when_tool_choice_none: bool = False
|
||||
"""If specified, exclude tool definitions in prompts when
|
||||
tool_choice='none'."""
|
||||
tool_call_parser: Optional[str] = None
|
||||
"""Select the tool call parser depending on the model that you're using.
|
||||
This is used to parse the model-generated tool call into OpenAI API format.
|
||||
Required for `--enable-auto-tool-choice`. You can choose any option from
|
||||
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
|
||||
tool_parser_plugin: str = ""
|
||||
"""Special the tool parser plugin write to parse the model-generated tool
|
||||
into OpenAI API format, the name register in this plugin can be used in
|
||||
`--tool-call-parser`."""
|
||||
tool_server: Optional[str] = None
|
||||
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
|
||||
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
|
||||
purpose."""
|
||||
log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH
|
||||
"""Path to logging config JSON file for both vllm and uvicorn"""
|
||||
max_log_len: Optional[int] = None
|
||||
"""Max number of prompt characters or prompt ID numbers being printed in
|
||||
log. The default of None means unlimited."""
|
||||
disable_fastapi_docs: bool = False
|
||||
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
|
||||
enable_prompt_tokens_details: bool = False
|
||||
"""If set to True, enable prompt_tokens_details in usage."""
|
||||
enable_server_load_tracking: bool = False
|
||||
"""If set to True, enable tracking server_load_metrics in the app state."""
|
||||
enable_force_include_usage: bool = False
|
||||
"""If set to True, including usage on every request."""
|
||||
enable_tokenizer_info_endpoint: bool = False
|
||||
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
||||
templates and other tokenizer configuration."""
|
||||
enable_log_outputs: bool = False
|
||||
"""If True, log model outputs (generations).
|
||||
Requires --enable-log-requests."""
|
||||
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
|
||||
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
|
||||
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
"""Maximum number of HTTP headers allowed in a request for h11 parser.
|
||||
Helps mitigate header abuse. Default: 256."""
|
||||
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
||||
"""If set to True, log the stack trace of error responses"""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
from vllm.engine.arg_utils import get_kwargs
|
||||
|
||||
frontend_kwargs = get_kwargs(FrontendArgs)
|
||||
|
||||
# Special case: allowed_origins, allowed_methods, allowed_headers all
|
||||
# need json.loads type
|
||||
# Should also remove nargs
|
||||
frontend_kwargs["allowed_origins"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_methods"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_headers"]["type"] = json.loads
|
||||
del frontend_kwargs["allowed_origins"]["nargs"]
|
||||
del frontend_kwargs["allowed_methods"]["nargs"]
|
||||
del frontend_kwargs["allowed_headers"]["nargs"]
|
||||
|
||||
# Special case: LoRA modules need custom parser action and
|
||||
# optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
|
||||
|
||||
# Special case: Middleware needs to append action
|
||||
frontend_kwargs["middleware"]["action"] = "append"
|
||||
frontend_kwargs["middleware"]["type"] = str
|
||||
if "nargs" in frontend_kwargs["middleware"]:
|
||||
del frontend_kwargs["middleware"]["nargs"]
|
||||
frontend_kwargs["middleware"]["default"] = []
|
||||
|
||||
# Special case: Tool call parser shows built-in options.
|
||||
valid_tool_parsers = list(ToolParserManager.tool_parsers.keys())
|
||||
parsers_str = ",".join(valid_tool_parsers)
|
||||
frontend_kwargs["tool_call_parser"]["metavar"] = (
|
||||
f"{{{parsers_str}}} or name registered in --tool-parser-plugin")
|
||||
|
||||
frontend_group = parser.add_argument_group(
|
||||
title="Frontend",
|
||||
description=FrontendArgs.__doc__,
|
||||
)
|
||||
|
||||
for key, value in frontend_kwargs.items():
|
||||
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Create the CLI argument parser used by the OpenAI API server.
|
||||
|
||||
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
|
||||
register all arguments instead of manually enumerating them here. This
|
||||
avoids code duplication and keeps the argument definitions in one place.
|
||||
"""
|
||||
parser.add_argument("model_tag",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="The model tag to serve "
|
||||
"(optional if specified in config)")
|
||||
parser.add_argument(
|
||||
"--headless",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.")
|
||||
parser.add_argument("--api-server-count",
|
||||
"-asc",
|
||||
type=int,
|
||||
default=1,
|
||||
help="How many API server processes to run.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="Read CLI options from a config file. "
|
||||
"Must be a YAML with the following options: "
|
||||
"https://docs.vllm.ai/en/latest/configuration/serve_args.html")
|
||||
parser = FrontendArgs.add_cli_args(parser)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
if args.enable_log_outputs and not args.enable_log_requests:
|
||||
raise TypeError("Error: --enable-log-outputs requires "
|
||||
"--enable-log-requests")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
return make_arg_parser(parser_for_docs)
|
||||
90
vllm/entrypoints/openai/logits_processors.py
Normal file
90
vllm/entrypoints/openai/logits_processors.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import lru_cache, partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
|
||||
"""Logits processor for constraining generated tokens to a
|
||||
specific set of token ids."""
|
||||
|
||||
def __init__(self, allowed_ids: Iterable[int]):
|
||||
self.allowed_ids: Optional[list[int]] = list(allowed_ids)
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, token_ids: list[int],
|
||||
logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.mask is None:
|
||||
self.mask = torch.ones((logits.shape[-1], ),
|
||||
dtype=torch.bool,
|
||||
device=logits.device)
|
||||
self.mask[self.allowed_ids] = False
|
||||
self.allowed_ids = None
|
||||
logits.masked_fill_(self.mask, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_allowed_token_ids_logits_processor(
|
||||
allowed_token_ids: frozenset[int],
|
||||
vocab_size: int,
|
||||
) -> LogitsProcessor:
|
||||
if not allowed_token_ids:
|
||||
raise ValueError("Empty allowed_token_ids provided")
|
||||
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains "
|
||||
"out-of-vocab token id")
|
||||
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(
|
||||
logit_bias: dict[int, float],
|
||||
token_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for token_id, bias in logit_bias.items():
|
||||
logits[token_id] += bias
|
||||
return logits
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]],
|
||||
allowed_token_ids: Optional[list[int]],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> list[LogitsProcessor]:
|
||||
logits_processors: list[LogitsProcessor] = []
|
||||
if logit_bias:
|
||||
try:
|
||||
# Convert token_id to integer
|
||||
# Clamp the bias between -100 and 100 per OpenAI API spec
|
||||
clamped_logit_bias: dict[int, float] = {
|
||||
int(token_id): min(100.0, max(-100.0, bias))
|
||||
for token_id, bias in logit_bias.items()
|
||||
}
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Found token_id in logit_bias that is not "
|
||||
"an integer or string representing an integer") from exc
|
||||
|
||||
# Check if token_id is within the vocab size
|
||||
for token_id, bias in clamped_logit_bias.items():
|
||||
if token_id < 0 or token_id >= len(tokenizer):
|
||||
raise ValueError(f"token_id {token_id} in logit_bias contains "
|
||||
"out-of-vocab token id")
|
||||
|
||||
logits_processors.append(
|
||||
partial(logit_bias_logits_processor, clamped_logit_bias))
|
||||
|
||||
if allowed_token_ids is not None:
|
||||
logits_processors.append(
|
||||
_get_allowed_token_ids_logits_processor(
|
||||
frozenset(allowed_token_ids), len(tokenizer)))
|
||||
|
||||
return logits_processors
|
||||
2757
vllm/entrypoints/openai/protocol.py
Normal file
2757
vllm/entrypoints/openai/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
491
vllm/entrypoints/openai/run_batch.py
Normal file
491
vllm/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
RerankResponse, ScoreResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help=
|
||||
"The path or url to a single input file. Currently supports local file "
|
||||
"paths, or the http protocol (http or https). If a URL is specified, "
|
||||
"the file should be available via HTTP GET.")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single output file. Currently supports "
|
||||
"local file paths, or web (http or https) urls. If a URL is specified,"
|
||||
" the file should be available via HTTP PUT.")
|
||||
parser.add_argument(
|
||||
"--output-tmp-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory to store the output file before uploading it "
|
||||
"to the output URL.",
|
||||
)
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=True`.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument("--enable-metrics",
|
||||
action="store_true",
|
||||
help="Enable Prometheus metrics")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="URL to the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible batch runner.")
|
||||
return make_arg_parser(parser).parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
|
||||
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: Optional[tqdm] = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
self._pbar = tqdm(total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, \
|
||||
session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_local_file(output_path: str,
|
||||
batch_outputs: list[BatchRequestOutput]) -> None:
|
||||
"""
|
||||
Write the responses to a local file.
|
||||
output_path: The path to write the responses to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
"""
|
||||
# We should make this async, but as long as run_batch runs as a
|
||||
# standalone program, blocking the event loop won't affect performance.
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=f)
|
||||
|
||||
|
||||
async def upload_data(output_url: str, data_or_file: str,
|
||||
from_file: bool) -> None:
|
||||
"""
|
||||
Upload a local file to a URL.
|
||||
output_url: The URL to upload the file to.
|
||||
data_or_file: Either the data to upload or the path to the file to upload.
|
||||
from_file: If True, data_or_file is the path to the file to upload.
|
||||
"""
|
||||
# Timeout is a common issue when uploading large files.
|
||||
# We retry max_retries times before giving up.
|
||||
max_retries = 5
|
||||
# Number of seconds to wait before retrying.
|
||||
delay = 5
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# We increase the timeout to 1000 seconds to allow
|
||||
# for large files (default is 300).
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
|
||||
total=1000)) as session:
|
||||
if from_file:
|
||||
with open(data_or_file, "rb") as file:
|
||||
async with session.put(output_url,
|
||||
data=file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload file.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
else:
|
||||
async with session.put(output_url,
|
||||
data=data_or_file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload data.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.error(
|
||||
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
|
||||
attempt,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
|
||||
) from e
|
||||
|
||||
|
||||
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
|
||||
output_tmp_dir: str) -> None:
|
||||
"""
|
||||
Write batch_outputs to a file or upload to a URL.
|
||||
path_or_url: The path or URL to write batch_outputs to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
output_tmp_dir: The directory to store the output file before uploading it
|
||||
to the output URL.
|
||||
"""
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
if output_tmp_dir is None:
|
||||
logger.info("Writing outputs to memory buffer")
|
||||
output_buffer = StringIO()
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=output_buffer)
|
||||
output_buffer.seek(0)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(
|
||||
path_or_url,
|
||||
output_buffer.read().strip().encode("utf-8"),
|
||||
from_file=False,
|
||||
)
|
||||
else:
|
||||
# Write responses to a temporary file and then upload it to the URL.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=output_tmp_dir,
|
||||
prefix="tmp_batch_output_",
|
||||
suffix=".jsonl",
|
||||
) as f:
|
||||
logger.info("Writing outputs to temporary local file %s",
|
||||
f.name)
|
||||
await write_local_file(f.name, batch_outputs)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(path_or_url, f.name, from_file=True)
|
||||
else:
|
||||
logger.info("Writing outputs to local file %s", path_or_url)
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
def make_error_request_output(request: BatchRequestInput,
|
||||
error_msg: str) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
|
||||
RerankResponse),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.error.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode")
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if "generate" in supported_tasks else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if "embed" in supported_tasks else None
|
||||
|
||||
enable_serving_reranking = ("classify" in supported_tasks and getattr(
|
||||
model_config.hf_config, "num_labels", 0) == 1)
|
||||
|
||||
openai_serving_scores = ServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if ("embed" in supported_tasks or enable_serving_reranking) else None
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: list[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
chat_handler_fn = openai_serving_chat.create_chat_completion if \
|
||||
openai_serving_chat is not None else None
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=
|
||||
"The model does not support Chat Completions API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
embed_handler_fn = openai_serving_embedding.create_embedding if \
|
||||
openai_serving_embedding is not None else None
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/score"):
|
||||
score_handler_fn = openai_serving_scores.create_score if \
|
||||
openai_serving_scores is not None else None
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/rerank"):
|
||||
rerank_handler_fn = openai_serving_scores.do_rerank if \
|
||||
openai_serving_scores is not None else None
|
||||
if rerank_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Rerank API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(rerank_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /score, /rerank ."
|
||||
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||
"score/rerank versions.",
|
||||
))
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
|
||||
|
||||
|
||||
async def main(args: Namespace):
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
|
||||
await run_batch(engine_client, vllm_config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
|
||||
1597
vllm/entrypoints/openai/serving_chat.py
Normal file
1597
vllm/entrypoints/openai/serving_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
173
vllm/entrypoints/openai/serving_classification.py
Normal file
173
vllm/entrypoints/openai/serving_classification.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ClassificationMixin(OpenAIServing):
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
if isinstance(ctx.request.input, str) and not ctx.request.input:
|
||||
return self.create_error_response(
|
||||
"Input cannot be empty for classification",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
renderer = self._get_renderer(ctx.tokenizer)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request))
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
probs = classify_res.probs
|
||||
predicted_index = int(np.argmax(probs))
|
||||
label = getattr(self.model_config.hf_config, "id2label",
|
||||
{}).get(predicted_index)
|
||||
|
||||
item = ClassificationData(
|
||||
index=idx,
|
||||
label=label,
|
||||
probs=probs,
|
||||
num_classes=len(probs),
|
||||
)
|
||||
|
||||
items.append(item)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ClassificationResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(self,
|
||||
request: ClassificationRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens)
|
||||
|
||||
|
||||
class ServingClassification(ClassificationMixin):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
model_name = self.models.model_name()
|
||||
request_id = (f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request)}")
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("classify", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
692
vllm/entrypoints/openai/serving_completion.py
Normal file
692
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,692 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import as_list, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info(
|
||||
"Using default completion sampling params from %s: %s",
|
||||
source,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"Echo is unsupported with prompt embeds.")
|
||||
|
||||
if (request.prompt_logprobs is not None
|
||||
and request.prompt_embeds is not None):
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds.")
|
||||
|
||||
request_id = (
|
||||
f"cmpl-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
# Mypy does not infer that engine_prompt will have only one of
|
||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
||||
# these as Union[object, the expected type], where it infers
|
||||
# object if engine_prompt is a subclass of one of the
|
||||
# typeddicts that defines both keys. Worse, because of
|
||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
||||
# infer the type of engine_prompt correctly because of the
|
||||
# enumerate. So we need an unnecessary cast here.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if is_embeds_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_embeds"])
|
||||
elif is_tokens_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_token_ids"])
|
||||
else:
|
||||
assert_never(engine_prompt)
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
enable_force_include_usage=self.enable_force_include_usage,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
enable_force_include_usage: bool,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
num_cached_tokens = None
|
||||
first_iteration = True
|
||||
|
||||
stream_options = request.stream_options
|
||||
if stream_options:
|
||||
include_usage = (stream_options.include_usage
|
||||
or enable_force_include_usage)
|
||||
include_continuous_usage = (include_usage and
|
||||
stream_options.continuous_usage_stats)
|
||||
else:
|
||||
include_usage, include_continuous_usage = False, False
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
|
||||
if first_iteration:
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
first_iteration = False
|
||||
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
|
||||
# Useful when request.return_token_ids is True
|
||||
# Returning prompt token IDs shares the same logic
|
||||
# with the echo implementation.
|
||||
prompt_token_ids_to_return: Optional[list[int]] = None
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids,
|
||||
*output.token_ids,
|
||||
]
|
||||
out_logprobs = [
|
||||
*(prompt_logprobs or []),
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
# has_echoed[i] is reused here to indicate whether
|
||||
# we have already returned the prompt token IDs.
|
||||
if not has_echoed[i]:
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
|
||||
if (not delta_text and not delta_token_ids
|
||||
and not previous_num_tokens[i]):
|
||||
# Chunked prefill case, don't return empty chunks
|
||||
continue
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
return_as_token_id=request.
|
||||
return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
prompt_token_ids=prompt_token_ids_to_return,
|
||||
token_ids=(as_list(output.token_ids) if
|
||||
request.return_token_ids else None),
|
||||
)
|
||||
],
|
||||
)
|
||||
if include_continuous_usage:
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=num_cached_tokens)
|
||||
|
||||
if include_usage:
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=final_usage_info,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = final_usage_info
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: list[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
kv_transfer_params = None
|
||||
last_final_res = None
|
||||
for final_res in final_res_batch:
|
||||
last_final_res = final_res
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
else:
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
prompt_token_ids=(prompt_token_ids
|
||||
if request.return_token_ids else None),
|
||||
token_ids=(as_list(output.token_ids)
|
||||
if request.return_token_ids else None),
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
if (self.enable_prompt_tokens_details and last_final_res
|
||||
and last_final_res.num_cached_tokens):
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=last_final_res.num_cached_tokens)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
if final_res_batch:
|
||||
kv_transfer_params = final_res_batch[0].kv_transfer_params
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: Optional[bool] = None,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: list[int] = []
|
||||
out_token_logprobs: list[Optional[float]] = []
|
||||
out_tokens: list[str] = []
|
||||
out_top_logprobs: list[Optional[dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
should_return_as_token_id = (return_as_token_id
|
||||
if return_as_token_id is not None else
|
||||
self.return_tokens_as_token_ids)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
max_input_length: Optional[int] = None,
|
||||
) -> RenderConfig:
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
|
||||
return RenderConfig(
|
||||
max_length=max_input_tokens_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo
|
||||
and not request.return_token_ids),
|
||||
)
|
||||
631
vllm/entrypoints/openai/serving_embedding.py
Normal file
631
vllm/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Final, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext,
|
||||
TextTokensPrompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingOutput, PoolingRequestOutput, RequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import chunk_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.embedding
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
|
||||
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class EmbeddingMixin(OpenAIServing):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
pooler_config = self.model_config.pooler_config
|
||||
|
||||
# Avoid repeated attribute lookups
|
||||
self.supports_chunked_processing = bool(
|
||||
pooler_config and pooler_config.enable_chunked_processing)
|
||||
self.max_embed_len = (pooler_config.max_embed_len if pooler_config
|
||||
and pooler_config.max_embed_len else None)
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template
|
||||
or ctx.chat_template,
|
||||
chat_template_content_format=ctx.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=ctx.request.add_generation_prompt,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_render_config(
|
||||
self, request: EmbeddingCompletionRequest) -> RenderConfig:
|
||||
# Set max_length based on chunked processing capability
|
||||
if self._should_use_chunked_processing(request):
|
||||
max_length = None
|
||||
else:
|
||||
max_length = self.max_embed_len or self.max_model_len
|
||||
|
||||
return RenderConfig(
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens)
|
||||
|
||||
@override
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
items: list[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
embedding_res = EmbeddingRequestOutput.from_base(final_res)
|
||||
|
||||
item = EmbeddingResponseData(
|
||||
index=idx,
|
||||
embedding=_get_embedding(embedding_res.outputs,
|
||||
ctx.request.encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _get_max_position_embeddings(self) -> int:
|
||||
"""Get the model's effective maximum sequence length for chunking."""
|
||||
return self.model_config.max_model_len
|
||||
|
||||
def _should_use_chunked_processing(self, request) -> bool:
|
||||
"""Check if chunked processing should be used for this request."""
|
||||
return isinstance(
|
||||
request,
|
||||
(EmbeddingCompletionRequest,
|
||||
EmbeddingChatRequest)) and self.supports_chunked_processing
|
||||
|
||||
async def _process_chunked_request(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
original_prompt: TextTokensPrompt,
|
||||
pooling_params,
|
||||
trace_headers,
|
||||
prompt_idx: int,
|
||||
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
|
||||
"""Process a single prompt using chunked processing."""
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
token_ids = original_prompt["prompt_token_ids"]
|
||||
|
||||
# Split into chunks using max_position_embeddings
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
# Process all chunks for MEAN aggregation
|
||||
for chunk_idx, chunk_tokens in enumerate(
|
||||
chunk_list(token_ids, max_pos_embeddings)):
|
||||
# Create a request ID for this chunk
|
||||
chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
|
||||
f"chunk-{chunk_idx}")
|
||||
|
||||
# Create engine prompt for this chunk
|
||||
chunk_engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Create chunk request prompt for logging
|
||||
chunk_text = ""
|
||||
chunk_request_prompt = TextTokensPrompt(
|
||||
prompt=chunk_text, prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Log the chunk
|
||||
self._log_inputs(chunk_request_id,
|
||||
chunk_request_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(original_generator)
|
||||
|
||||
return generators
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
"""Override to support chunked processing for embedding requests."""
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingCompletionRequest, EmbeddingChatRequest)):
|
||||
# Check if chunked processing is enabled for pooling models
|
||||
enable_chunked = self._should_use_chunked_processing(request)
|
||||
|
||||
# Use max_position_embeddings for chunked processing decisions
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
# Determine the effective max length for validation
|
||||
if self.max_embed_len is not None:
|
||||
# Use max_embed_len for validation instead of max_model_len
|
||||
length_type = "maximum embedding input length"
|
||||
max_length_value = self.max_embed_len
|
||||
else:
|
||||
# Fall back to max_model_len validation (original behavior)
|
||||
length_type = "maximum context length"
|
||||
max_length_value = self.max_model_len
|
||||
|
||||
validation_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input.")
|
||||
|
||||
chunked_processing_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input "
|
||||
"or enable chunked processing.")
|
||||
|
||||
# Check if input exceeds max length
|
||||
if token_num > max_length_value:
|
||||
raise ValueError(
|
||||
validation_error_msg.format(
|
||||
length_type=length_type,
|
||||
max_length_value=max_length_value,
|
||||
token_num=token_num))
|
||||
|
||||
# Check for chunked processing
|
||||
# when exceeding max_position_embeddings
|
||||
if token_num > max_pos_embeddings:
|
||||
if enable_chunked:
|
||||
# Allow long inputs when chunked processing is enabled
|
||||
logger.info(
|
||||
"Input length %s exceeds max_position_embeddings "
|
||||
"%s, will use chunked processing", token_num,
|
||||
max_pos_embeddings)
|
||||
else:
|
||||
raise ValueError(
|
||||
chunked_processing_error_msg.format(
|
||||
length_type="maximum position embeddings length",
|
||||
max_length_value=max_pos_embeddings,
|
||||
token_num=token_num))
|
||||
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# For other request types, use the parent's implementation
|
||||
return super()._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _is_text_tokens_prompt(self, prompt) -> bool:
|
||||
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Optional[Mapping[str, str]],
|
||||
prompt_index: int,
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
"""Create a generator for a single prompt using standard processing."""
|
||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Override to support chunked processing."""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
|
||||
# Check if we should use chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
# If no chunked processing needed, delegate to parent class
|
||||
if not use_chunked:
|
||||
return await super()._prepare_generators(ctx)
|
||||
|
||||
# Custom logic for chunked processing
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
# Verify and set the task for pooling params
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if self._is_text_tokens_prompt(engine_prompt):
|
||||
# Cast to TextTokensPrompt since we've verified
|
||||
# prompt_token_ids
|
||||
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
|
||||
if (len(text_tokens_prompt["prompt_token_ids"])
|
||||
> max_pos_embeddings):
|
||||
# Use chunked processing for this prompt
|
||||
chunk_generators = await self._process_chunked_request(
|
||||
ctx, text_tokens_prompt, pooling_params,
|
||||
trace_headers, i)
|
||||
generators.extend(chunk_generators)
|
||||
continue
|
||||
|
||||
# Normal processing for short prompts or non-token prompts
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt, pooling_params, trace_headers, i)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect and aggregate batch results
|
||||
with support for chunked processing.
|
||||
|
||||
For chunked requests, performs online aggregation to
|
||||
minimize memory usage.
|
||||
For regular requests, collects results normally.
|
||||
"""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
# Check if we used chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
if not use_chunked:
|
||||
return await super()._collect_batch(ctx=ctx)
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
prompt_aggregators: dict[int, dict[str, Any]] = {}
|
||||
short_prompts_results: dict[int, PoolingRequestOutput] = {}
|
||||
|
||||
async for result_idx, result in ctx.result_generator:
|
||||
if "-chunk-" in result.request_id:
|
||||
# Extract prompt_idx from chunked request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
prompt_idx = int(parts[parts.index("prompt") + 1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback: extract from result_idx if parsing fails
|
||||
prompt_idx = result_idx
|
||||
|
||||
# Initialize aggregator for this prompt if needed
|
||||
if prompt_idx not in prompt_aggregators:
|
||||
prompt_aggregators[prompt_idx] = {
|
||||
'weighted_sum': None,
|
||||
'total_weight': 0,
|
||||
'chunk_count': 0,
|
||||
'request_id': result.request_id.split("-chunk-")[0]
|
||||
}
|
||||
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
# MEAN pooling with online weighted averaging
|
||||
# Ensure result is PoolingRequestOutput
|
||||
# for embedding processing
|
||||
if not isinstance(result, PoolingRequestOutput):
|
||||
return self.create_error_response(
|
||||
f"Expected PoolingRequestOutput for "
|
||||
f"chunked embedding, got "
|
||||
f"{type(result).__name__}")
|
||||
|
||||
# Handle both PoolingOutput and
|
||||
# EmbeddingOutput types
|
||||
if hasattr(result.outputs, 'data'):
|
||||
# PoolingOutput case
|
||||
embedding_data = result.outputs.data
|
||||
elif hasattr(result.outputs, 'embedding'):
|
||||
# EmbeddingOutput case -
|
||||
# convert embedding list to tensor
|
||||
embedding_data = result.outputs.embedding
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Unsupported output type: "
|
||||
f"{type(result.outputs).__name__}")
|
||||
|
||||
if not isinstance(embedding_data, torch.Tensor):
|
||||
embedding_data = torch.tensor(embedding_data,
|
||||
dtype=torch.float32)
|
||||
|
||||
if result.prompt_token_ids is None:
|
||||
return self.create_error_response(
|
||||
"prompt_token_ids cannot be None for "
|
||||
"chunked processing")
|
||||
weight = len(result.prompt_token_ids)
|
||||
|
||||
weighted_embedding = embedding_data.to(
|
||||
dtype=torch.float32) * weight
|
||||
|
||||
if aggregator['weighted_sum'] is None:
|
||||
# First chunk
|
||||
aggregator['weighted_sum'] = weighted_embedding
|
||||
else:
|
||||
# Accumulate
|
||||
aggregator['weighted_sum'] += weighted_embedding
|
||||
|
||||
aggregator['total_weight'] += weight
|
||||
aggregator['chunk_count'] += 1
|
||||
else:
|
||||
# Non-chunked result - extract prompt_idx from request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
# Last part should be prompt index
|
||||
prompt_idx = int(parts[-1])
|
||||
except (ValueError, IndexError):
|
||||
prompt_idx = result_idx # Fallback to result_idx
|
||||
|
||||
short_prompts_results[prompt_idx] = cast(
|
||||
PoolingRequestOutput, result)
|
||||
|
||||
# Finalize aggregated results
|
||||
final_res_batch: list[Union[PoolingRequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
if prompt_idx in prompt_aggregators:
|
||||
# Finalize MEAN aggregation for this chunked prompt
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
weighted_sum = aggregator['weighted_sum']
|
||||
total_weight = aggregator['total_weight']
|
||||
|
||||
if (weighted_sum is not None
|
||||
and isinstance(weighted_sum, torch.Tensor)
|
||||
and isinstance(total_weight,
|
||||
(int, float)) and total_weight > 0):
|
||||
|
||||
# Compute final mean embedding
|
||||
final_embedding = weighted_sum / total_weight
|
||||
|
||||
# Create a PoolingRequestOutput
|
||||
# for the aggregated result
|
||||
pooling_output_data = PoolingOutput(
|
||||
data=final_embedding)
|
||||
|
||||
# Get original prompt token IDs for this prompt
|
||||
original_prompt = ctx.engine_prompts[prompt_idx]
|
||||
if not self._is_text_tokens_prompt(original_prompt):
|
||||
return self.create_error_response(
|
||||
f"Chunked prompt {prompt_idx} is not a "
|
||||
f"TextTokensPrompt")
|
||||
|
||||
original_token_ids = cast(
|
||||
TextTokensPrompt,
|
||||
original_prompt)["prompt_token_ids"]
|
||||
|
||||
pooling_request_output = PoolingRequestOutput(
|
||||
request_id=aggregator['request_id'],
|
||||
prompt_token_ids=original_token_ids,
|
||||
outputs=pooling_output_data,
|
||||
finished=True)
|
||||
|
||||
final_res_batch.append(pooling_request_output)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Failed to aggregate chunks "
|
||||
f"for prompt {prompt_idx}")
|
||||
elif prompt_idx in short_prompts_results:
|
||||
final_res_batch.append(
|
||||
cast(PoolingRequestOutput,
|
||||
short_prompts_results[prompt_idx]))
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Result not found for prompt {prompt_idx}")
|
||||
|
||||
ctx.final_res_batch = cast(
|
||||
list[Union[RequestOutput, PoolingRequestOutput]],
|
||||
final_res_batch)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
request_id_prefix = "embd"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
model_name = self.models.model_name()
|
||||
request_id = (
|
||||
f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
|
||||
ctx = EmbeddingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ServeContext[EmbeddingRequest],
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
992
vllm/entrypoints/openai/serving_engine.py
Normal file
992
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.entrypoints.context import ConversationContext
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
PoolingResponse, RerankRequest,
|
||||
ResponsesRequest, ScoreRequest,
|
||||
ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
|
||||
RenderConfig)
|
||||
# yapf: enable
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
|
||||
MultiModalDataDict, MultiModalUUIDDict)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
|
||||
merge_async_iterators, random_uuid)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CompletionLikeRequest = Union[
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
RerankRequest,
|
||||
ClassificationRequest,
|
||||
ScoreRequest,
|
||||
TokenizeCompletionRequest,
|
||||
]
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||
AnyRequest = Union[
|
||||
CompletionLikeRequest,
|
||||
ChatLikeRequest,
|
||||
SpeechToTextRequest,
|
||||
ResponsesRequest,
|
||||
IOProcessorRequest,
|
||||
]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
TranscriptionResponse,
|
||||
TokenizeResponse,
|
||||
PoolingResponse,
|
||||
ClassificationResponse,
|
||||
ScoreResponse,
|
||||
]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
prompt: str
|
||||
prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class EmbedsPrompt(TypedDict):
|
||||
prompt_embeds: torch.Tensor
|
||||
|
||||
|
||||
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
|
||||
|
||||
|
||||
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
|
||||
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
|
||||
and "prompt_embeds" in prompt)
|
||||
|
||||
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
class RequestProcessingMixin(BaseModel):
|
||||
"""
|
||||
Mixin for request processing,
|
||||
handling prompt preparation and engine input.
|
||||
"""
|
||||
|
||||
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
||||
engine_prompts: Optional[list[EngineTokensPrompt]] = []
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ResponseGenerationMixin(BaseModel):
|
||||
"""
|
||||
Mixin for response generation,
|
||||
managing result generators and final batch results.
|
||||
"""
|
||||
|
||||
result_generator: Optional[AsyncGenerator[tuple[int, Union[
|
||||
RequestOutput, PoolingRequestOutput]], None]] = None
|
||||
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
|
||||
default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ServeContext(
|
||||
RequestProcessingMixin,
|
||||
ResponseGenerationMixin,
|
||||
BaseModel,
|
||||
Generic[RequestT],
|
||||
):
|
||||
# Shared across all requests
|
||||
request: RequestT
|
||||
raw_request: Optional[Request] = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = Field(default_factory=lambda: int(time.time()))
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: Optional[AnyTokenizer] = None
|
||||
|
||||
# `protected_namespaces` resolves Pydantic v2's warning
|
||||
# on conflict with protected namespace "model_"
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
ClassificationServeContext = ServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
|
||||
chat_template: Optional[str] = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption
|
||||
|
||||
|
||||
# Used to resolve the Pydantic error related to
|
||||
# forward reference of MultiModalDataDict in TokensPrompt
|
||||
RequestProcessingMixin.model_rebuild()
|
||||
ServeContext.model_rebuild()
|
||||
ClassificationServeContext.model_rebuild()
|
||||
EmbeddingServeContext.model_rebuild()
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
so you can easily tell “this ID came from Embedding vs Classification.”
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer,
|
||||
AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
||||
"""
|
||||
Get a Renderer instance with the provided tokenizer.
|
||||
Uses shared async tokenizer pool for efficiency.
|
||||
"""
|
||||
return CompletionRenderer(
|
||||
model_config=self.model_config,
|
||||
tokenizer=tokenizer,
|
||||
async_tokenizer_pool=self._async_tokenizer_pool)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: Any,
|
||||
) -> RenderConfig:
|
||||
"""
|
||||
Build and return a `RenderConfig` for an endpoint.
|
||||
|
||||
Used by the renderer to control how prompts are prepared
|
||||
(e.g., tokenization and length handling). Endpoints should
|
||||
implement this with logic appropriate to their request type.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
|
||||
"""
|
||||
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
|
||||
given tokenizer.
|
||||
"""
|
||||
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
|
||||
if async_tokenizer is None:
|
||||
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
|
||||
self._async_tokenizer_pool[tokenizer] = async_tokenizer
|
||||
return async_tokenizer
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Default preprocessing hook. Subclasses may override
|
||||
to prepare `ctx` (classification, embedding, etc.).
|
||||
"""
|
||||
return None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
"""
|
||||
Default response builder. Subclass may override this method
|
||||
to return the appropriate response object.
|
||||
"""
|
||||
return self.create_error_response("unimplemented endpoint")
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
|
||||
generation = self._pipeline(ctx)
|
||||
|
||||
async for response in generation:
|
||||
return response
|
||||
|
||||
return self.create_error_response("No response yielded from pipeline")
|
||||
|
||||
async def _pipeline(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
|
||||
"""Execute the request processing pipeline yielding responses."""
|
||||
if error := await self._check_model(ctx.request):
|
||||
yield error
|
||||
if error := self._validate_request(ctx):
|
||||
yield error
|
||||
|
||||
preprocess_ret = await self._preprocess(ctx)
|
||||
if isinstance(preprocess_ret, ErrorResponse):
|
||||
yield preprocess_ret
|
||||
|
||||
generators_ret = await self._prepare_generators(ctx)
|
||||
if isinstance(generators_ret, ErrorResponse):
|
||||
yield generators_ret
|
||||
|
||||
collect_ret = await self._collect_batch(ctx)
|
||||
if isinstance(collect_ret, ErrorResponse):
|
||||
yield collect_ret
|
||||
|
||||
yield self._build_response(ctx)
|
||||
|
||||
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
|
||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if (truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > self.max_model_len):
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
return None
|
||||
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
if not hasattr(ctx.request, "to_pooling_params"):
|
||||
return self.create_error_response(
|
||||
"Request type does not support pooling parameters")
|
||||
|
||||
return ctx.request.to_pooling_params()
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Schedule the request and get the result generator."""
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect batch results from the result generator."""
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
final_res_batch: list[Optional[Union[RequestOutput,
|
||||
PoolingRequestOutput]]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
async for i, res in ctx.result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
if None in final_res_batch:
|
||||
return self.create_error_response(
|
||||
"Failed to generate results for all prompts")
|
||||
|
||||
ctx.final_res_batch = [
|
||||
res for res in final_res_batch if res is not None
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> ErrorResponse:
|
||||
if self.log_error_stack:
|
||||
exc_type, _, _ = sys.exc_info()
|
||||
if exc_type is not None:
|
||||
traceback.print_exc()
|
||||
else:
|
||||
traceback.print_stack()
|
||||
return ErrorResponse(error=ErrorInfo(
|
||||
message=message, type=err_type, code=status_code.value))
|
||||
|
||||
def create_streaming_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> str:
|
||||
json_str = json.dumps(
|
||||
self.create_error_response(message=message,
|
||||
err_type=err_type,
|
||||
status_code=status_code).model_dump())
|
||||
return json_str
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
error_response = None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in self.models.lora_requests:
|
||||
return None
|
||||
if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
|
||||
(load_result := await self.models.resolve_lora(request.model))):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if (isinstance(load_result, ErrorResponse) and
|
||||
load_result.error.code == HTTPStatus.BAD_REQUEST.value):
|
||||
error_response = load_result
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
def _get_active_default_mm_loras(
|
||||
self, request: AnyRequest) -> Optional[LoRARequest]:
|
||||
"""Determine if there are any active default multimodal loras."""
|
||||
# TODO: Currently this is only enabled for chat completions
|
||||
# to be better aligned with only being enabled for .generate
|
||||
# when run offline. It would be nice to support additional
|
||||
# tasks types in the future.
|
||||
message_types = self._get_message_types(request)
|
||||
default_mm_loras = set()
|
||||
|
||||
for lora in self.models.lora_requests.values():
|
||||
# Best effort match for default multimodal lora adapters;
|
||||
# There is probably a better way to do this, but currently
|
||||
# this matches against the set of 'types' in any content lists
|
||||
# up until '_', e.g., to match audio_url -> audio
|
||||
if lora.lora_name in message_types:
|
||||
default_mm_loras.add(lora)
|
||||
|
||||
# Currently only support default modality specific loras if
|
||||
# we have exactly one lora matched on the request.
|
||||
if len(default_mm_loras) == 1:
|
||||
return default_mm_loras.pop()
|
||||
return None
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
supports_default_mm_loras: bool = False,
|
||||
) -> Optional[LoRARequest]:
|
||||
if request.model in self.models.lora_requests:
|
||||
return self.models.lora_requests[request.model]
|
||||
|
||||
# Currently only support default modality specific loras
|
||||
# if we have exactly one lora matched on the request.
|
||||
if supports_default_mm_loras:
|
||||
default_mm_lora = self._get_active_default_mm_loras(request)
|
||||
if default_mm_lora is not None:
|
||||
return default_mm_lora
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _get_message_types(self, request: AnyRequest) -> set[str]:
|
||||
"""Retrieve the set of types from message content dicts up
|
||||
until `_`; we use this to match potential multimodal data
|
||||
with default per modality loras.
|
||||
"""
|
||||
message_types: set[str] = set()
|
||||
|
||||
if not hasattr(request, "messages"):
|
||||
return message_types
|
||||
|
||||
for message in request.messages:
|
||||
if (isinstance(message, dict) and "content" in message
|
||||
and isinstance(message["content"], list)):
|
||||
for content_dict in message["content"]:
|
||||
if "type" in content_dict:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
async def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
|
||||
if (self.model_config.encoder_config is not None
|
||||
and self.model_config.encoder_config.get(
|
||||
"do_lower_case", False)):
|
||||
prompt = prompt.lower()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = await async_tokenizer(
|
||||
prompt, add_special_tokens=add_special_tokens)
|
||||
elif truncate_prompt_tokens < 0:
|
||||
# Negative means we cap at the model's max length
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=self.max_model_len,
|
||||
)
|
||||
else:
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
input_ids = encoded.input_ids
|
||||
input_text = prompt
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
async def _normalize_prompt_tokens_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt_ids: list[int],
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
) -> TextTokensPrompt:
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
input_ids = prompt_ids
|
||||
elif truncate_prompt_tokens < 0:
|
||||
input_ids = prompt_ids[-self.max_model_len:]
|
||||
else:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
|
||||
if tokenizer is None:
|
||||
input_text = ""
|
||||
else:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
input_text = await async_tokenizer.decode(input_ids)
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
(
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ScoreRequest,
|
||||
RerankRequest,
|
||||
ClassificationRequest,
|
||||
),
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
# since these requests don't generate tokens.
|
||||
if token_num > self.max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreRequest: "score",
|
||||
ClassificationRequest: "classification",
|
||||
}
|
||||
operation = operations.get(type(request),
|
||||
"embedding generation")
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input.")
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(
|
||||
request,
|
||||
(TokenizeCompletionRequest, TokenizeChatRequest,
|
||||
DetokenizeRequest),
|
||||
):
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# chat completion endpoint supports max_completion_tokens
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages.")
|
||||
|
||||
if (max_tokens is not None
|
||||
and token_num + max_tokens > self.max_model_len):
|
||||
raise ValueError(
|
||||
"'max_tokens' or 'max_completion_tokens' is too large: "
|
||||
f"{max_tokens}. This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
|
||||
f" - {token_num}).")
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
async def _tokenize_prompt_input_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_input: Union[str, list[int]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
"""
|
||||
A simpler implementation that tokenizes a single prompt input.
|
||||
"""
|
||||
async for result in self._tokenize_prompt_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
[prompt_input],
|
||||
add_special_tokens=add_special_tokens,
|
||||
):
|
||||
return result
|
||||
raise ValueError("No results yielded from tokenization")
|
||||
|
||||
async def _tokenize_prompt_inputs_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_inputs: Iterable[Union[str, list[int]]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> AsyncGenerator[TextTokensPrompt, None]:
|
||||
"""
|
||||
A simpler implementation that tokenizes multiple prompt inputs.
|
||||
"""
|
||||
for prompt in prompt_inputs:
|
||||
if isinstance(prompt, str):
|
||||
yield await self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
prompt=prompt,
|
||||
tokenizer=tokenizer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield await self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
prompt_ids=prompt,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: Union[ChatLikeRequest, ResponsesRequest],
|
||||
tokenizer: AnyTokenizer,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tool_dicts: Optional[list[dict[str, Any]]] = None,
|
||||
documents: Optional[list[dict[str, str]]] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Sequence[RequestPrompt],
|
||||
list[EngineTokensPrompt],
|
||||
]:
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tool_dicts,
|
||||
chat_template_content_format,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
|
||||
messages,
|
||||
model_config,
|
||||
tokenizer,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
_chat_template_kwargs: dict[str, Any] = dict(
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
)
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
|
||||
request_prompt: Union[str, list[int]]
|
||||
|
||||
if tokenizer is None:
|
||||
request_prompt = "placeholder"
|
||||
elif isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
conversation=conversation,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
|
||||
should_parse_tools = tool_parser is not None and (hasattr(
|
||||
request, "tool_choice") and request.tool_choice != "none")
|
||||
|
||||
if should_parse_tools:
|
||||
if not isinstance(request, ChatCompletionRequest):
|
||||
msg = "Tool usage is only supported for Chat Completions API"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
request = tool_parser(tokenizer).adjust_request( # type: ignore
|
||||
request=request)
|
||||
|
||||
if tokenizer is None:
|
||||
assert isinstance(request_prompt, str), (
|
||||
"Prompt has to be a string",
|
||||
"when the tokenizer is not initialised",
|
||||
)
|
||||
prompt_inputs = TextTokensPrompt(prompt=request_prompt,
|
||||
prompt_token_ids=[1])
|
||||
elif isinstance(request_prompt, str):
|
||||
prompt_inputs = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For MistralTokenizer
|
||||
assert is_list_of(request_prompt, int), (
|
||||
"Prompt has to be either a string or a list of token ids")
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(request_prompt),
|
||||
prompt_token_ids=request_prompt,
|
||||
)
|
||||
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
|
||||
return conversation, [request_prompt], [engine_prompt]
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
request_prompt: RequestPrompt,
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
sampling_params: SamplingParams,
|
||||
context: ConversationContext,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
priority: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
orig_priority = priority
|
||||
while True:
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
request_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
async for res in generator:
|
||||
context.append_output(res)
|
||||
# NOTE(woosuk): The stop condition is handled by the engine.
|
||||
yield context
|
||||
|
||||
if not context.need_builtin_tool_call():
|
||||
# The model did not ask for a tool call, so we're done.
|
||||
break
|
||||
|
||||
# Call the tool and update the context with the result.
|
||||
tool_output = await context.call_tool()
|
||||
context.append_output(tool_output)
|
||||
|
||||
# TODO: uncomment this and enable tool output streaming
|
||||
# yield context
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids.
|
||||
prompt_token_ids = context.render_for_completion()
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=prompt_token_ids)
|
||||
request_prompt = prompt_token_ids
|
||||
# Update the sampling params.
|
||||
sampling_params.max_tokens = self.max_model_len - len(
|
||||
prompt_token_ids)
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: Union[RequestPrompt, PromptType],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
prompt, prompt_token_ids, prompt_embeds = None, None, None
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
elif isinstance(inputs, list):
|
||||
prompt_token_ids = inputs
|
||||
else:
|
||||
prompt = getattr(inputs, 'prompt', None)
|
||||
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Optional[Mapping[str, str]]:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(raw_request: Optional[Request],
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
default = default or random_uuid()
|
||||
if raw_request is None:
|
||||
return default
|
||||
|
||||
return raw_request.headers.get("X-Request-Id", default)
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(
|
||||
logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
return_as_token_id: bool = False,
|
||||
) -> str:
|
||||
if return_as_token_id:
|
||||
return f"token_id:{token_id}"
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name: Optional[str]) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
|
||||
def clamp_prompt_logprobs(
|
||||
prompt_logprobs: Union[PromptLogprobs,
|
||||
None], ) -> Union[PromptLogprobs, None]:
|
||||
if prompt_logprobs is None:
|
||||
return prompt_logprobs
|
||||
|
||||
for logprob_dict in prompt_logprobs:
|
||||
if logprob_dict is None:
|
||||
continue
|
||||
for logprob_values in logprob_dict.values():
|
||||
if logprob_values.logprob == float("-inf"):
|
||||
logprob_values.logprob = -9999.0
|
||||
return prompt_logprobs
|
||||
288
vllm/entrypoints/openai/serving_models.py
Normal file
288
vllm/entrypoints/openai/serving_models.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
|
||||
):
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
if self.static_lora_modules is None:
|
||||
return
|
||||
for lora in self.static_lora_modules:
|
||||
load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
|
||||
lora_name=lora.name)
|
||||
load_result = await self.load_lora_adapter(
|
||||
request=load_request, base_model_name=lora.base_model_name)
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.error.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoRAAdapterRequest,
|
||||
base_model_name: Optional[str] = None
|
||||
) -> Union[ErrorResponse, str]:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(
|
||||
request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_path = request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
lora_request = LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path)
|
||||
if base_model_name is not None and self.is_base_model(
|
||||
base_model_name):
|
||||
lora_request.base_model_name = base_model_name
|
||||
|
||||
# Validate that the adapter can be loaded into the engine
|
||||
# This will also pre-load it for incoming requests
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
except Exception as e:
|
||||
error_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "No adapter found" in str(e):
|
||||
error_type = "NotFoundError"
|
||||
status_code = HTTPStatus.NOT_FOUND
|
||||
|
||||
return create_error_response(message=str(e),
|
||||
err_type=error_type,
|
||||
status_code=status_code)
|
||||
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
|
||||
lora_name, lora_path)
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(
|
||||
request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# Safe to delete now since we hold the lock
|
||||
del self.lora_requests[lora_name]
|
||||
logger.info("Removed LoRA adapter: name '%s'", lora_name)
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if request.lora_name in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been "
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if 'lora_name' is not provided return an error
|
||||
if not request.lora_name:
|
||||
return create_error_response(
|
||||
message=
|
||||
"'lora_name' needs to be provided to unload a LoRA adapter.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if request.lora_name not in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(
|
||||
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
if lora_name in self.lora_requests:
|
||||
return self.lora_requests[lora_name]
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(
|
||||
base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name, resolver.__class__.__name__)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.", lora_name,
|
||||
resolver.__class__.__name__, e)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(f"LoRA adapter '{lora_name}' was found "
|
||||
"but could not be loaded."),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(error=ErrorInfo(
|
||||
message=message, type=err_type, code=status_code.value))
|
||||
276
vllm/entrypoints/openai/serving_pooling.py
Normal file
276
vllm/entrypoints/openai/serving_pooling.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_data(
|
||||
output: PoolingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.data.tolist()
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
pt_float32 = output.data.to(dtype=torch.float32)
|
||||
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
|
||||
return base64.b64encode(pooling_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = self.models.model_name()
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id)
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
# In pooling requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.input,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify("encode", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request,
|
||||
(PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_pooling_response(
|
||||
final_res_batch_checked,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
request.encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_pooling_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> PoolingResponse:
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=_get_data(final_res.outputs, encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return PoolingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self, request: PoolingCompletionRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens)
|
||||
1709
vllm/entrypoints/openai/serving_responses.py
Normal file
1709
vllm/entrypoints/openai/serving_responses.py
Normal file
File diff suppressed because it is too large
Load Diff
479
vllm/entrypoints/openai/serving_score.py
Normal file
479
vllm/entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ScoreResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
_validate_score_input_lens,
|
||||
compress_token_type_ids,
|
||||
get_score_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
input_texts = texts_1 + texts_2
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
|
||||
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
|
||||
text_token_prompt = \
|
||||
self._validate_input(
|
||||
request,
|
||||
tok_result["input_ids"],
|
||||
input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[Optional[PoolingRequestOutput]] =\
|
||||
[None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_texts_1: list[PoolingRequestOutput] = []
|
||||
emb_texts_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(texts_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_1.append(emb)
|
||||
|
||||
for i in range(len(texts_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_2.append(emb)
|
||||
|
||||
if len(emb_texts_1) == 1:
|
||||
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
|
||||
embed_1=emb_texts_1,
|
||||
embed_2=emb_texts_2)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: Union[str, ScoreContentPartParam],
|
||||
data_2: Union[str, ScoreContentPartParam],
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=data_1,
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"],
|
||||
full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
data_1: Union[list[str], list[ScoreContentPartParam]],
|
||||
data_2: Union[list[str], list[ScoreContentPartParam]],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
preprocess_async = make_async(self._preprocess_score,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
*(preprocess_async(request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2) for t1, t2 in input_pairs))
|
||||
|
||||
for full_prompt, engine_prompt in preprocessed_prompts:
|
||||
request_prompts.append(full_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
default_pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
default_pooling_params.verify("score", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=default_pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
|
||||
pooling_params = default_pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
pooling_params.extra_kwargs = {
|
||||
"compressed_token_type_ids": compressed
|
||||
}
|
||||
else:
|
||||
pooling_params = (default_pooling_params)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[
|
||||
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
data_1: Union[list[str], str, ScoreMultiModalParam],
|
||||
data_2: Union[list[str], str, ScoreMultiModalParam],
|
||||
request: Union[ScoreRequest, RerankRequest],
|
||||
request_id: str,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if not self.model_config.is_multimodal_model and (isinstance(
|
||||
data_1, dict) or isinstance(data_2, dict)):
|
||||
raise ValueError(
|
||||
f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501
|
||||
)
|
||||
|
||||
if isinstance(data_1, str):
|
||||
data_1 = [data_1]
|
||||
elif isinstance(data_1, dict):
|
||||
data_1 = data_1.get("content") # type: ignore[assignment]
|
||||
|
||||
if isinstance(data_2, str):
|
||||
data_2 = [data_2]
|
||||
elif isinstance(data_2, dict):
|
||||
data_2 = data_2.get("content") # type: ignore[assignment]
|
||||
|
||||
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
data_1=data_1, # type: ignore[arg-type]
|
||||
data_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=data_1, # type: ignore[arg-type]
|
||||
texts_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[ScoreResponse, ErrorResponse]:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.text_1,
|
||||
request.text_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self.models.model_name(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
top_n = request.top_n if request.top_n > 0 else (
|
||||
len(documents)
|
||||
if isinstance(documents, list) else len(documents["content"]))
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self.models.model_name(),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: list[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: Union[list[str], ScoreMultiModalParam],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]) if isinstance(
|
||||
documents, list) else RerankDocument(
|
||||
multi_modal=documents["content"][idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
196
vllm/entrypoints/openai/serving_tokenization.py
Normal file
196
vllm/entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TokenizerInfoResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
request: TokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[TokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
tool_dicts = (None if request.tools is None else
|
||||
[tool.model_dump() for tool in request.tools])
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.prompt,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
input_ids: list[int] = []
|
||||
for engine_prompt in engine_prompts:
|
||||
self._log_inputs(request_id,
|
||||
engine_prompt,
|
||||
params=None,
|
||||
lora_request=lora_request)
|
||||
|
||||
if isinstance(engine_prompt,
|
||||
dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
token_strs=token_strs,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len)
|
||||
|
||||
async def create_detokenize(
|
||||
self,
|
||||
request: DetokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[DetokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.tokens,
|
||||
params=None,
|
||||
lora_request=lora_request)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
|
||||
async def get_tokenizer_info(
|
||||
self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
|
||||
"""Get comprehensive tokenizer information."""
|
||||
try:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
|
||||
return TokenizerInfoResponse(**info)
|
||||
except Exception as e:
|
||||
return self.create_error_response(
|
||||
f"Failed to get tokenizer info: {str(e)}")
|
||||
|
||||
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
|
||||
return RenderConfig(add_special_tokens=request.add_special_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
tokenizer: AnyTokenizer
|
||||
chat_template: Optional[str]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return the tokenizer configuration."""
|
||||
return self._get_tokenizer_config()
|
||||
|
||||
def _get_tokenizer_config(self) -> dict[str, Any]:
|
||||
"""Get tokenizer configuration directly from the tokenizer object."""
|
||||
config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
|
||||
|
||||
# Remove file path fields
|
||||
config.pop("vocab_file", None)
|
||||
config.pop("merges_file", None)
|
||||
|
||||
config = self._make_json_serializable(config)
|
||||
config["tokenizer_class"] = type(self.tokenizer).__name__
|
||||
if self.chat_template:
|
||||
config["chat_template"] = self.chat_template
|
||||
return config
|
||||
|
||||
def _make_json_serializable(self, obj):
|
||||
"""Convert any non-JSON-serializable objects to serializable format."""
|
||||
if hasattr(obj, "content"):
|
||||
return obj.content
|
||||
elif isinstance(obj, dict):
|
||||
return {k: self._make_json_serializable(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self._make_json_serializable(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
136
vllm/entrypoints/openai/serving_transcription.py
Normal file
136
vllm/entrypoints/openai/serving_transcription.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe",
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
|
||||
ErrorResponse]:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self, request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
388
vllm/entrypoints/openai/speech_to_text.py
Normal file
388
vllm/entrypoints/openai/speech_to_text.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import cached_property
|
||||
from typing import Callable, Literal, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
SpeechToTextRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse]
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: Literal["transcribe", "translate"] = "transcribe",
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
self.task_type = task_type
|
||||
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
model_config, task_type)
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsTranscription]:
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsTranscription], model_cls)
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
# Skip to_language validation to avoid extra logging for Whisper.
|
||||
to_language = self.model_cls.validate_language(request.to_language) \
|
||||
if request.to_language else None
|
||||
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (self.asr_config.allow_audio_chunking
|
||||
and duration > self.asr_config.max_audio_clip_s)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
# returns a valid PromptType.
|
||||
prompt = self.model_cls.get_generation_prompt(
|
||||
audio=chunk,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language=language,
|
||||
task_type=self.task_type,
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, duration
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[T],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> Union[T, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ['text', 'json']:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for "
|
||||
f"{self.task_type.title()}.")
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
|
||||
None]]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
# It will not display special tokens like <|startoftranscript|>
|
||||
request.prompt,
|
||||
params=sampling_params,
|
||||
lora_request=None)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
) for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(request, list_result_generator,
|
||||
request_id, request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
|
||||
if self.task_type == "transcribe":
|
||||
# add usage in TranscriptionResponse.
|
||||
usage = {
|
||||
"type": "duration",
|
||||
# rounded up as per openAI specs
|
||||
"seconds": int(math.ceil(duration_s)),
|
||||
}
|
||||
final_response = cast(T, response_class(text=text,
|
||||
usage=usage))
|
||||
else:
|
||||
# no usage in response for translation task
|
||||
final_response = cast(
|
||||
T, response_class(text=text)) # type: ignore[call-arg]
|
||||
|
||||
return final_response
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
|
||||
response_stream_choice_class: Union[
|
||||
type[TranscriptionResponseStreamChoice],
|
||||
type[TranslationResponseStreamChoice]],
|
||||
stream_response_class: Union[type[TranscriptionStreamResponse],
|
||||
type[TranslationStreamResponse]],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens = len(res.prompt_token_ids)
|
||||
if audio_tokens := self.model_cls.get_num_audio_tokens(
|
||||
audio_duration_s, self.asr_config,
|
||||
self.model_config):
|
||||
num_prompt_tokens += audio_tokens
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = stream_response_class(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(self, audio_data: np.ndarray,
|
||||
sample_rate: int) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
|
||||
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start,
|
||||
search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int,
|
||||
end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
min_energy_window = self.asr_config.min_energy_split_window_size
|
||||
assert min_energy_window is not None
|
||||
for i in range(0, len(segment) - min_energy_window, min_energy_window):
|
||||
window = segment[i:i + min_energy_window]
|
||||
energy = (window**2).mean()**0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
|
||||
55
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
55
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .jamba_tool_parser import JambaToolParser
|
||||
from .kimi_k2_tool_parser import KimiK2ToolParser
|
||||
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .longcat_tool_parser import LongcatFlashToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .openai_tool_parser import OpenAIToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
from .qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
from .seed_oss_tool_parser import SeedOssToolParser
|
||||
from .step3_tool_parser import Step3ToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser",
|
||||
"ToolParserManager",
|
||||
"Granite20bFCToolParser",
|
||||
"GraniteToolParser",
|
||||
"Hermes2ProToolParser",
|
||||
"MistralToolParser",
|
||||
"Internlm2ToolParser",
|
||||
"Llama3JsonToolParser",
|
||||
"JambaToolParser",
|
||||
"Llama4PythonicToolParser",
|
||||
"LongcatFlashToolParser",
|
||||
"PythonicToolParser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"DeepSeekV31ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
"Qwen3CoderToolParser",
|
||||
"Qwen3XMLToolParser",
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
"OpenAIToolParser",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user