init
This commit is contained in:
0
entrypoints/__init__.py
Normal file
0
entrypoints/__init__.py
Normal file
178
entrypoints/api_server.py
Normal file
178
entrypoints/api_server.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
|
||||
The request should be a JSON object with the following fields:
|
||||
- prompt: the prompt to use for the generation.
|
||||
- stream: whether to stream the results or not.
|
||||
- other fields: the sampling parameters (See `SamplingParams` for details).
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
return await _generate(request_dict, raw_request=request)
|
||||
|
||||
|
||||
@with_cancellation
|
||||
async def _generate(request_dict: dict, raw_request: Request) -> Response:
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
assert engine is not None
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
async def stream_results() -> AsyncGenerator[bytes, None]:
|
||||
async for request_output in results_generator:
|
||||
prompt = request_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [
|
||||
prompt + output.text for output in request_output.outputs
|
||||
]
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\n").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
try:
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
except asyncio.CancelledError:
|
||||
return Response(status_code=499)
|
||||
|
||||
assert final_output is not None
|
||||
prompt = final_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||
ret = {"text": text_outputs}
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
global app
|
||||
|
||||
app.root_path = args.root_path
|
||||
return app
|
||||
|
||||
|
||||
async def init_app(
|
||||
args: Namespace,
|
||||
llm_engine: Optional[AsyncLLMEngine] = None,
|
||||
) -> FastAPI:
|
||||
app = build_app(args)
|
||||
|
||||
global engine
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = (llm_engine
|
||||
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.API_SERVER))
|
||||
app.state.engine_client = engine
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(args: Namespace,
|
||||
llm_engine: Optional[AsyncLLMEngine] = None,
|
||||
**uvicorn_kwargs: Any) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
set_ulimit()
|
||||
|
||||
app = await init_app(args, llm_engine)
|
||||
assert engine is not None
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=None,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=parser.check_port, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-certfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The CA certificates file")
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy")
|
||||
parser.add_argument("--log-level", type=str, default="debug")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(run_server(args))
|
||||
1299
entrypoints/chat_utils.py
Normal file
1299
entrypoints/chat_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
0
entrypoints/cli/__init__.py
Normal file
0
entrypoints/cli/__init__.py
Normal file
0
entrypoints/cli/benchmark/__init__.py
Normal file
0
entrypoints/cli/benchmark/__init__.py
Normal file
39
entrypoints/cli/benchmark/base.py
Normal file
39
entrypoints/cli/benchmark/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class BenchmarkSubcommandBase(CLISubcommand):
|
||||
""" The base class of subcommands for vllm bench. """
|
||||
|
||||
@property
|
||||
def help(self) -> str:
|
||||
"""The help message of the subcommand."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
"""Add the CLI arguments to the parser."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
args: The arguments to the command.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
parser = subparsers.add_parser(
|
||||
self.name,
|
||||
help=self.help,
|
||||
description=self.help,
|
||||
usage=f"vllm bench {self.name} [options]")
|
||||
self.add_cli_args(parser)
|
||||
return parser
|
||||
30
entrypoints/cli/benchmark/latency.py
Normal file
30
entrypoints/cli/benchmark/latency.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.latency import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
|
||||
""" The `latency` subcommand for vllm bench. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "latency"
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def help(self) -> str:
|
||||
return "Benchmark the latency of a single batch of requests."
|
||||
|
||||
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkLatencySubcommand()]
|
||||
54
entrypoints/cli/benchmark/main.py
Normal file
54
entrypoints/cli/benchmark/main.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
import vllm.entrypoints.cli.benchmark.latency
|
||||
import vllm.entrypoints.cli.benchmark.serve
|
||||
import vllm.entrypoints.cli.benchmark.throughput
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
BENCHMARK_CMD_MODULES = [
|
||||
vllm.entrypoints.cli.benchmark.latency,
|
||||
vllm.entrypoints.cli.benchmark.serve,
|
||||
vllm.entrypoints.cli.benchmark.throughput,
|
||||
]
|
||||
|
||||
|
||||
class BenchmarkSubcommand(CLISubcommand):
|
||||
""" The `bench` subcommand for the vLLM CLI. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "bench"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
args.dispatch_function(args)
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
if args.bench_type in self.cmds:
|
||||
self.cmds[args.bench_type].validate(args)
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
bench_parser = subparsers.add_parser(
|
||||
"bench",
|
||||
help="vLLM bench subcommand.",
|
||||
description="vLLM bench subcommand.",
|
||||
usage="vllm bench <bench_type> [options]")
|
||||
bench_subparsers = bench_parser.add_subparsers(required=True,
|
||||
dest="bench_type")
|
||||
self.cmds = {}
|
||||
for cmd_module in BENCHMARK_CMD_MODULES:
|
||||
new_cmds = cmd_module.cmd_init()
|
||||
for cmd in new_cmds:
|
||||
cmd.subparser_init(bench_subparsers).set_defaults(
|
||||
dispatch_function=cmd.cmd)
|
||||
self.cmds[cmd.name] = cmd
|
||||
return bench_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkSubcommand()]
|
||||
30
entrypoints/cli/benchmark/serve.py
Normal file
30
entrypoints/cli/benchmark/serve.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.serve import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
|
||||
""" The `serve` subcommand for vllm bench. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "serve"
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def help(self) -> str:
|
||||
return "Benchmark the online serving throughput."
|
||||
|
||||
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkServingSubcommand()]
|
||||
30
entrypoints/cli/benchmark/throughput.py
Normal file
30
entrypoints/cli/benchmark/throughput.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.throughput import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
|
||||
""" The `throughput` subcommand for vllm bench. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "throughput"
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def help(self) -> str:
|
||||
return "Benchmark offline inference throughput."
|
||||
|
||||
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkThroughputSubcommand()]
|
||||
35
entrypoints/cli/collect_env.py
Normal file
35
entrypoints/cli/collect_env.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
from vllm.collect_env import main as collect_env_main
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class CollectEnvSubcommand(CLISubcommand):
|
||||
"""The `collect-env` subcommand for the vLLM CLI. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "collect-env"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Collect information about the environment."""
|
||||
collect_env_main()
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
collect_env_parser = subparsers.add_parser(
|
||||
"collect-env",
|
||||
help="Start collecting environment information.",
|
||||
description="Start collecting environment information.",
|
||||
usage="vllm collect-env")
|
||||
return collect_env_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [CollectEnvSubcommand()]
|
||||
65
entrypoints/cli/main.py
Normal file
65
entrypoints/cli/main.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# The CLI entrypoint to vLLM.
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import vllm.entrypoints.cli.benchmark.main
|
||||
import vllm.entrypoints.cli.collect_env
|
||||
import vllm.entrypoints.cli.openai
|
||||
import vllm.entrypoints.cli.run_batch
|
||||
import vllm.entrypoints.cli.serve
|
||||
import vllm.version
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
CMD_MODULES = [
|
||||
vllm.entrypoints.cli.openai,
|
||||
vllm.entrypoints.cli.serve,
|
||||
vllm.entrypoints.cli.benchmark.main,
|
||||
vllm.entrypoints.cli.collect_env,
|
||||
vllm.entrypoints.cli.run_batch,
|
||||
]
|
||||
|
||||
|
||||
def register_signal_handlers():
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTSTP, signal_handler)
|
||||
|
||||
|
||||
def main():
|
||||
cli_env_setup()
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM CLI",
|
||||
epilog=VLLM_SUBCMD_PARSER_EPILOG,
|
||||
)
|
||||
parser.add_argument('-v',
|
||||
'--version',
|
||||
action='version',
|
||||
version=vllm.version.__version__)
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
cmds = {}
|
||||
for cmd_module in CMD_MODULES:
|
||||
new_cmds = cmd_module.cmd_init()
|
||||
for cmd in new_cmds:
|
||||
cmd.subparser_init(subparsers).set_defaults(
|
||||
dispatch_function=cmd.cmd)
|
||||
cmds[cmd.name] = cmd
|
||||
args = parser.parse_args()
|
||||
if args.subparser in cmds:
|
||||
cmds[args.subparser].validate(args)
|
||||
|
||||
if hasattr(args, "dispatch_function"):
|
||||
args.dispatch_function(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
205
entrypoints/cli/openai.py
Normal file
205
entrypoints/cli/openai.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Commands that act as an interactive OpenAI API client
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def _register_signal_handlers():
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTSTP, signal_handler)
|
||||
|
||||
|
||||
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
|
||||
_register_signal_handlers()
|
||||
|
||||
base_url = args.url
|
||||
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
|
||||
openai_client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
if args.model_name:
|
||||
model_name = args.model_name
|
||||
else:
|
||||
available_models = openai_client.models.list()
|
||||
model_name = available_models.data[0].id
|
||||
|
||||
print(f"Using model: {model_name}")
|
||||
|
||||
return model_name, openai_client
|
||||
|
||||
|
||||
def chat(system_prompt: Optional[str], model_name: str,
|
||||
client: OpenAI) -> None:
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
return
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
chat_completion = client.chat.completions.create(model=model_name,
|
||||
messages=conversation)
|
||||
|
||||
response_message = chat_completion.choices[0].message
|
||||
output = response_message.content
|
||||
|
||||
conversation.append(response_message) # type: ignore
|
||||
print(output)
|
||||
|
||||
|
||||
def _add_query_options(
|
||||
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="url of the running OpenAI-Compatible RESTful API server")
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The model name used in prompt completion, default to "
|
||||
"the first model in list models API call."))
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"API key for OpenAI services. If provided, this api key "
|
||||
"will overwrite the api key obtained through environment variables."
|
||||
))
|
||||
return parser
|
||||
|
||||
|
||||
class ChatCommand(CLISubcommand):
|
||||
"""The `chat` subcommand for the vLLM CLI. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "chat"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
system_prompt = args.system_prompt
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
if args.quick:
|
||||
conversation.append({"role": "user", "content": args.quick})
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model=model_name, messages=conversation)
|
||||
print(chat_completion.choices[0].message.content)
|
||||
return
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
return
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model=model_name, messages=conversation)
|
||||
|
||||
response_message = chat_completion.choices[0].message
|
||||
output = response_message.content
|
||||
|
||||
conversation.append(response_message) # type: ignore
|
||||
print(output)
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
chat_parser = subparsers.add_parser(
|
||||
"chat",
|
||||
help="Generate chat completions via the running API server.",
|
||||
description="Generate chat completions via the running API server.",
|
||||
usage="vllm chat [options]")
|
||||
_add_query_options(chat_parser)
|
||||
chat_parser.add_argument(
|
||||
"--system-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The system prompt to be added to the chat template, "
|
||||
"used for models that support system prompts."))
|
||||
chat_parser.add_argument("-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="MESSAGE",
|
||||
help=("Send a single prompt as MESSAGE "
|
||||
"and print the response, then exit."))
|
||||
return chat_parser
|
||||
|
||||
|
||||
class CompleteCommand(CLISubcommand):
|
||||
"""The `complete` subcommand for the vLLM CLI. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "complete"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
|
||||
if args.quick:
|
||||
completion = client.completions.create(model=model_name,
|
||||
prompt=args.quick)
|
||||
print(completion.choices[0].text)
|
||||
return
|
||||
|
||||
print("Please enter prompt to complete:")
|
||||
while True:
|
||||
input_prompt = input("> ")
|
||||
completion = client.completions.create(model=model_name,
|
||||
prompt=input_prompt)
|
||||
output = completion.choices[0].text
|
||||
print(output)
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
complete_parser = subparsers.add_parser(
|
||||
"complete",
|
||||
help=("Generate text completions based on the given prompt "
|
||||
"via the running API server."),
|
||||
description=("Generate text completions based on the given prompt "
|
||||
"via the running API server."),
|
||||
usage="vllm complete [options]")
|
||||
_add_query_options(complete_parser)
|
||||
complete_parser.add_argument(
|
||||
"-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="PROMPT",
|
||||
help=
|
||||
"Send a single prompt and print the completion output, then exit.")
|
||||
return complete_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ChatCommand(), CompleteCommand()]
|
||||
62
entrypoints/cli/run_batch.py
Normal file
62
entrypoints/cli/run_batch.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.logger import logger
|
||||
from vllm.entrypoints.openai.run_batch import main as run_batch_main
|
||||
from vllm.entrypoints.openai.run_batch import make_arg_parser
|
||||
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
|
||||
show_filtered_argument_or_group_from_help)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
class RunBatchSubcommand(CLISubcommand):
|
||||
"""The `run-batch` subcommand for vLLM CLI."""
|
||||
|
||||
def __init__(self):
|
||||
self.name = "run-batch"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server.
|
||||
# LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(run_batch_main(args))
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
run_batch_parser = subparsers.add_parser(
|
||||
"run-batch",
|
||||
help="Run batch prompts and write results to file.",
|
||||
description=(
|
||||
"Run batch prompts using vLLM's OpenAI-compatible API.\n"
|
||||
"Supports local or HTTP input/output files."),
|
||||
usage=
|
||||
"vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
|
||||
)
|
||||
run_batch_parser = make_arg_parser(run_batch_parser)
|
||||
show_filtered_argument_or_group_from_help(run_batch_parser,
|
||||
"run-batch")
|
||||
run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG
|
||||
return run_batch_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [RunBatchSubcommand()]
|
||||
328
entrypoints/cli/serve.py
Normal file
328
entrypoints/cli/serve.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
|
||||
setup_server)
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
|
||||
show_filtered_argument_or_group_from_help)
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.core_client import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
|
||||
CoreEngineActorManager, EngineZmqAddresses,
|
||||
get_engine_client_zmq_addr,
|
||||
wait_for_completion_or_failure,
|
||||
wait_for_engine_startup)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServeSubcommand(CLISubcommand):
|
||||
"""The `serve` subcommand for the vLLM CLI. """
|
||||
|
||||
def __init__(self):
|
||||
self.name = "serve"
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
# If model is specified in CLI (as positional arg), it takes precedence
|
||||
if hasattr(args, 'model_tag') and args.model_tag is not None:
|
||||
args.model = args.model_tag
|
||||
|
||||
if args.headless or args.api_server_count < 1:
|
||||
run_headless(args)
|
||||
elif args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
serve_parser = subparsers.add_parser(
|
||||
"serve",
|
||||
help="Start the vLLM OpenAI Compatible API server.",
|
||||
description="Start the vLLM OpenAI Compatible API server.",
|
||||
usage="vllm serve [model_tag] [options]")
|
||||
serve_parser.add_argument("model_tag",
|
||||
type=str,
|
||||
nargs='?',
|
||||
help="The model tag to serve "
|
||||
"(optional if specified in config)")
|
||||
serve_parser.add_argument(
|
||||
"--headless",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.")
|
||||
serve_parser.add_argument(
|
||||
'--data-parallel-start-rank',
|
||||
'-dpr',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Starting data parallel rank for secondary nodes.')
|
||||
serve_parser.add_argument('--api-server-count',
|
||||
'-asc',
|
||||
type=int,
|
||||
default=1,
|
||||
help='How many API server processes to run.')
|
||||
serve_parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default='',
|
||||
required=False,
|
||||
help="Read CLI options from a config file."
|
||||
"Must be a YAML with the following options:"
|
||||
"https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
|
||||
)
|
||||
|
||||
serve_parser = make_arg_parser(serve_parser)
|
||||
show_filtered_argument_or_group_from_help(serve_parser, "serve")
|
||||
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG
|
||||
return serve_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ServeSubcommand()]
|
||||
|
||||
|
||||
def run_headless(args: argparse.Namespace):
|
||||
|
||||
if args.api_server_count > 1:
|
||||
raise ValueError("api_server_count can't be set in headless mode")
|
||||
|
||||
# Create the EngineConfig.
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError("Headless mode is only supported for V1")
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in "
|
||||
"headless mode")
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
logger.debug("Received %d signal.", signum)
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
logger.info(
|
||||
"Launching %d data parallel engine(s) in headless mode, "
|
||||
"with head node address %s.", local_engine_count, handshake_address)
|
||||
|
||||
# Create the engines.
|
||||
engine_manager = CoreEngineProcManager(
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=args.data_parallel_start_rank,
|
||||
local_start_index=0,
|
||||
vllm_config=vllm_config,
|
||||
on_head_node=False,
|
||||
handshake_address=handshake_address,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
|
||||
try:
|
||||
engine_manager.join_first()
|
||||
finally:
|
||||
logger.info("Shutting down.")
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
assert not args.headless
|
||||
num_api_servers = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
listen_address, sock = setup_server(args)
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if num_api_servers > 1:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError("api_server_count > 1 is only supported for V1")
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
||||
"with api_server_count > 1")
|
||||
|
||||
if model_config.is_multimodal_model and not (
|
||||
model_config.disable_mm_preprocessor_cache):
|
||||
logger.warning(
|
||||
"Multi-model preprocessor cache will be disabled for"
|
||||
" api_server_count > 1")
|
||||
model_config.disable_mm_preprocessor_cache = True
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
assert parallel_config.data_parallel_rank == 0
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
local_only = local_engine_count == dp_size
|
||||
|
||||
# Set up input and output addresses.
|
||||
input_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
output_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=input_addresses,
|
||||
outputs=output_addresses,
|
||||
)
|
||||
|
||||
# Set up coordinator for dp > 1.
|
||||
coordinator = None
|
||||
stats_update_address = None
|
||||
if dp_size > 1:
|
||||
coordinator = DPCoordinator(parallel_config)
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
stats_update_address = coordinator.get_stats_publish_address()
|
||||
logger.info("Started DP Coordinator process (PID: %d)",
|
||||
coordinator.proc.pid)
|
||||
|
||||
if parallel_config.data_parallel_backend == "ray":
|
||||
logger.info("Starting ray-based data parallel backend")
|
||||
|
||||
engine_actor_manager = CoreEngineActorManager(
|
||||
vllm_config=vllm_config,
|
||||
addresses=addresses,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
# Start API servers using the manager
|
||||
api_server_manager = APIServerProcessManager(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=input_addresses,
|
||||
output_addresses=output_addresses,
|
||||
stats_update_address=stats_update_address)
|
||||
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=engine_actor_manager,
|
||||
coordinator=coordinator)
|
||||
return
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
# Start local engines.
|
||||
if not local_engine_count:
|
||||
local_engine_manager = None
|
||||
else:
|
||||
local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
handshake_address=handshake_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=0,
|
||||
local_start_index=0)
|
||||
|
||||
# Start API servers using the manager
|
||||
api_server_manager = APIServerProcessManager(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=input_addresses,
|
||||
output_addresses=output_addresses,
|
||||
stats_update_address=stats_update_address)
|
||||
|
||||
# Wait for engine handshakes to complete.
|
||||
core_engines = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(dp_size)
|
||||
]
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
core_engines,
|
||||
parallel_config,
|
||||
vllm_config.cache_config,
|
||||
local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(listen_address,
|
||||
sock,
|
||||
args,
|
||||
client_config=None,
|
||||
**uvicorn_kwargs) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr.
|
||||
from multiprocessing import current_process
|
||||
process_name = current_process().name
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config,
|
||||
**uvicorn_kwargs))
|
||||
25
entrypoints/cli/types.py
Normal file
25
entrypoints/cli/types.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class CLISubcommand:
|
||||
"""Base class for CLI argument handlers."""
|
||||
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
# No validation by default
|
||||
pass
|
||||
|
||||
def subparser_init(
|
||||
self,
|
||||
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
147
entrypoints/launcher.py
Normal file
147
entrypoints/launcher.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import socket
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(app: FastAPI,
|
||||
sock: Optional[socket.socket],
|
||||
enable_ssl_refresh: bool = False,
|
||||
**uvicorn_kwargs: Any):
|
||||
logger.info("Available routes are:")
|
||||
for route in app.routes:
|
||||
methods = getattr(route, "methods", None)
|
||||
path = getattr(route, "path", None)
|
||||
|
||||
if methods is None or path is None:
|
||||
continue
|
||||
|
||||
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
||||
|
||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||
config.load()
|
||||
server = uvicorn.Server(config)
|
||||
_add_shutdown_handlers(app, server)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
watchdog_task = loop.create_task(
|
||||
watchdog_loop(server, app.state.engine_client))
|
||||
server_task = loop.create_task(
|
||||
server.serve(sockets=[sock] if sock else None))
|
||||
|
||||
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
|
||||
ssl_context=config.ssl,
|
||||
key_path=config.ssl_keyfile,
|
||||
cert_path=config.ssl_certfile,
|
||||
ca_path=config.ssl_ca_certs)
|
||||
|
||||
def signal_handler() -> None:
|
||||
# prevents the uvicorn signal handler to exit early
|
||||
server_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
|
||||
async def dummy_shutdown() -> None:
|
||||
pass
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
await server_task
|
||||
return dummy_shutdown()
|
||||
except asyncio.CancelledError:
|
||||
port = uvicorn_kwargs["port"]
|
||||
process = find_process_using_port(port)
|
||||
if process is not None:
|
||||
logger.debug(
|
||||
"port %s is used by process %s launched with command:\n%s",
|
||||
port, process, " ".join(process.cmdline()))
|
||||
logger.info("Shutting down FastAPI HTTP server.")
|
||||
return server.shutdown()
|
||||
finally:
|
||||
watchdog_task.cancel()
|
||||
|
||||
|
||||
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
# Watchdog task that runs in the background, checking
|
||||
# for error state in the engine. Needed to trigger shutdown
|
||||
# if an exception arises is StreamingResponse() generator.
|
||||
"""
|
||||
VLLM_WATCHDOG_TIME_S = 5.0
|
||||
while True:
|
||||
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
|
||||
terminate_if_errored(server, engine)
|
||||
|
||||
|
||||
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
See discussions here on shutting down a uvicorn server
|
||||
https://github.com/encode/uvicorn/discussions/1103
|
||||
In this case we cannot await the server shutdown here
|
||||
because handler must first return to close the connection
|
||||
for this request.
|
||||
"""
|
||||
engine_errored = engine.errored and not engine.is_running
|
||||
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
|
||||
server.should_exit = True
|
||||
|
||||
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
||||
"""
|
||||
VLLM V1 AsyncLLM catches exceptions and returns
|
||||
only two types: EngineGenerateError and EngineDeadError.
|
||||
|
||||
EngineGenerateError is raised by the per request generate()
|
||||
method. This error could be request specific (and therefore
|
||||
recoverable - e.g. if there is an error in input processing).
|
||||
|
||||
EngineDeadError is raised by the background output_handler
|
||||
method. This error is global and therefore not recoverable.
|
||||
|
||||
We register these @app.exception_handlers to return nice
|
||||
responses to the end user if they occur and shut down if needed.
|
||||
See https://fastapi.tiangolo.com/tutorial/handling-errors/
|
||||
for more details on how exception handlers work.
|
||||
|
||||
If an exception is encountered in a StreamingResponse
|
||||
generator, the exception is not raised, since we already sent
|
||||
a 200 status. Rather, we send an error message as the next chunk.
|
||||
Since the exception is not raised, this means that the server
|
||||
will not automatically shut down. Instead, we use the watchdog
|
||||
background task for check for errored state.
|
||||
"""
|
||||
|
||||
@app.exception_handler(RuntimeError)
|
||||
@app.exception_handler(AsyncEngineDeadError)
|
||||
@app.exception_handler(MQEngineDeadError)
|
||||
@app.exception_handler(EngineDeadError)
|
||||
@app.exception_handler(EngineGenerateError)
|
||||
async def runtime_exception_handler(request: Request, __):
|
||||
terminate_if_errored(
|
||||
server=server,
|
||||
engine=request.app.state.engine_client,
|
||||
)
|
||||
|
||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
1545
entrypoints/llm.py
Normal file
1545
entrypoints/llm.py
Normal file
File diff suppressed because it is too large
Load Diff
50
entrypoints/logger.py
Normal file
50
entrypoints/logger.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RequestLogger:
|
||||
|
||||
def __init__(self, *, max_log_len: Optional[int]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_log_len = max_log_len
|
||||
|
||||
def log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_embeds: Optional[torch.Tensor],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if prompt is not None:
|
||||
prompt = prompt[:max_log_len]
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
prompt_token_ids = prompt_token_ids[:max_log_len]
|
||||
|
||||
logger.info(
|
||||
"Received request %s: prompt: %r, "
|
||||
"params: %s, prompt_token_ids: %s, "
|
||||
"prompt_embeds shape: %s, "
|
||||
"lora_request: %s, prompt_adapter_request: %s.", request_id,
|
||||
prompt, params, prompt_token_ids,
|
||||
prompt_embeds.shape if prompt_embeds is not None else None,
|
||||
lora_request, prompt_adapter_request)
|
||||
0
entrypoints/openai/__init__.py
Normal file
0
entrypoints/openai/__init__.py
Normal file
1387
entrypoints/openai/api_server.py
Normal file
1387
entrypoints/openai/api_server.py
Normal file
File diff suppressed because it is too large
Load Diff
315
entrypoints/openai/cli_args.py
Normal file
315
entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union, get_args
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: list[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, '']: # Skip if item is None or empty string
|
||||
continue
|
||||
if '=' in item and ',' not in item: # Old format: name=path
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
adapter_list: list[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument("--host",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="Host name.")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number.")
|
||||
parser.add_argument(
|
||||
"--uvicorn-log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
|
||||
help="Log level for uvicorn.")
|
||||
parser.add_argument("--disable-uvicorn-access-log",
|
||||
action="store_true",
|
||||
help="Disable uvicorn access log.")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="Allow credentials.")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="Allowed origins.")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="Allowed methods.")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="Allowed headers.")
|
||||
parser.add_argument("--api-key",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="If provided, the server will require this key "
|
||||
"to be presented in the header.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in either 'name=path' format"
|
||||
"or JSON format. "
|
||||
"Example (old format): ``'name=path'`` "
|
||||
"Example (new format): "
|
||||
"``{\"name\": \"name\", \"path\": \"lora_path\", "
|
||||
"\"base_model_name\": \"id\"}``")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=PromptAdapterParserAction,
|
||||
help="Prompt adapter configurations in the format name=path. "
|
||||
"Multiple adapters can be specified.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model.")
|
||||
parser.add_argument(
|
||||
'--chat-template-content-format',
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=get_args(ChatTemplateContentFormatOption),
|
||||
help='The format to render message content within a chat template.'
|
||||
'\n\n'
|
||||
'* "string" will render the content as a string. '
|
||||
'Example: ``"Hello World"``\n'
|
||||
'* "openai" will render the content as a list of dictionaries, '
|
||||
'similar to OpenAI schema. '
|
||||
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"``request.add_generation_prompt=true``.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="The file path to the SSL key file.")
|
||||
parser.add_argument("--ssl-certfile",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="The file path to the SSL cert file.")
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="The CA certificates file.")
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=optional_type(str),
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--middleware",
|
||||
type=optional_type(str),
|
||||
action="append",
|
||||
default=[],
|
||||
help="Additional ASGI middleware to apply to the app. "
|
||||
"We accept multiple --middleware arguments. "
|
||||
"The value should be an import path. "
|
||||
"If a function is provided, vLLM will add it to the server "
|
||||
"using ``@app.middleware('http')``. "
|
||||
"If a class is provided, vLLM will add it to the server "
|
||||
"using ``app.add_middleware()``. ")
|
||||
parser.add_argument(
|
||||
"--return-tokens-as-token-ids",
|
||||
action="store_true",
|
||||
help="When ``--max-logprobs`` is specified, represents single tokens "
|
||||
" as strings of the form 'token_id:{token_id}' so that tokens "
|
||||
"that are not JSON-encodable can be identified.")
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
help="If specified, will run the OpenAI frontend server in the same "
|
||||
"process as the model serving engine.")
|
||||
parser.add_argument(
|
||||
"--enable-request-id-headers",
|
||||
action="store_true",
|
||||
help="If specified, API server will add X-Request-Id header to "
|
||||
"responses. Caution: this hurts performance at high QPS.")
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable auto tool choice for supported models. Use "
|
||||
"``--tool-call-parser`` to specify which parser to use.")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for ``--enable-auto-tool-choice``.")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="",
|
||||
help=
|
||||
"Special the tool parser plugin write to parse the model-generated tool"
|
||||
" into OpenAI API format, the name register in this plugin can be used "
|
||||
"in ``--tool-call-parser``.")
|
||||
|
||||
parser.add_argument(
|
||||
"--log-config-file",
|
||||
type=str,
|
||||
default=envs.VLLM_LOGGING_CONFIG_PATH,
|
||||
help="Path to logging config JSON file for both vllm and uvicorn",
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
' The default of None means unlimited.')
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-fastapi-docs",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
parser.add_argument(
|
||||
"--enable-server-load-tracking",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"If set to True, enable tracking server_load_metrics in the app state."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
if args.enable_prompt_embeds and args.enable_prompt_adapter:
|
||||
raise ValueError(
|
||||
"Cannot use prompt embeds and prompt adapter at the same time.")
|
||||
|
||||
|
||||
def log_non_default_args(args: argparse.Namespace):
|
||||
non_default_args = {}
|
||||
parser = make_arg_parser(FlexibleArgumentParser())
|
||||
for arg, default in vars(parser.parse_args([])).items():
|
||||
if default != getattr(args, arg):
|
||||
non_default_args[arg] = getattr(args, arg)
|
||||
logger.info("non-default args: %s", non_default_args)
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
return make_arg_parser(parser_for_docs)
|
||||
90
entrypoints/openai/logits_processors.py
Normal file
90
entrypoints/openai/logits_processors.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import lru_cache, partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
|
||||
"""Logits processor for constraining generated tokens to a
|
||||
specific set of token ids."""
|
||||
|
||||
def __init__(self, allowed_ids: Iterable[int]):
|
||||
self.allowed_ids: Optional[list[int]] = list(allowed_ids)
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, token_ids: list[int],
|
||||
logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.mask is None:
|
||||
self.mask = torch.ones((logits.shape[-1], ),
|
||||
dtype=torch.bool,
|
||||
device=logits.device)
|
||||
self.mask[self.allowed_ids] = False
|
||||
self.allowed_ids = None
|
||||
logits.masked_fill_(self.mask, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_allowed_token_ids_logits_processor(
|
||||
allowed_token_ids: frozenset[int],
|
||||
vocab_size: int,
|
||||
) -> LogitsProcessor:
|
||||
if not allowed_token_ids:
|
||||
raise ValueError("Empty allowed_token_ids provided")
|
||||
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains "
|
||||
"out-of-vocab token id")
|
||||
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(
|
||||
logit_bias: dict[int, float],
|
||||
token_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for token_id, bias in logit_bias.items():
|
||||
logits[token_id] += bias
|
||||
return logits
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]],
|
||||
allowed_token_ids: Optional[list[int]],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> list[LogitsProcessor]:
|
||||
logits_processors: list[LogitsProcessor] = []
|
||||
if logit_bias:
|
||||
try:
|
||||
# Convert token_id to integer
|
||||
# Clamp the bias between -100 and 100 per OpenAI API spec
|
||||
clamped_logit_bias: dict[int, float] = {
|
||||
int(token_id): min(100.0, max(-100.0, bias))
|
||||
for token_id, bias in logit_bias.items()
|
||||
}
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Found token_id in logit_bias that is not "
|
||||
"an integer or string representing an integer") from exc
|
||||
|
||||
# Check if token_id is within the vocab size
|
||||
for token_id, bias in clamped_logit_bias.items():
|
||||
if token_id < 0 or token_id >= len(tokenizer):
|
||||
raise ValueError(f"token_id {token_id} in logit_bias contains "
|
||||
"out-of-vocab token id")
|
||||
|
||||
logits_processors.append(
|
||||
partial(logit_bias_logits_processor, clamped_logit_bias))
|
||||
|
||||
if allowed_token_ids is not None:
|
||||
logits_processors.append(
|
||||
_get_allowed_token_ids_logits_processor(
|
||||
frozenset(allowed_token_ids), len(tokenizer)))
|
||||
|
||||
return logits_processors
|
||||
1913
entrypoints/openai/protocol.py
Normal file
1913
entrypoints/openai/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
463
entrypoints/openai/run_batch.py
Normal file
463
entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,463 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from collections.abc import Awaitable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger, logger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
RerankResponse, ScoreResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help=
|
||||
"The path or url to a single input file. Currently supports local file "
|
||||
"paths, or the http protocol (http or https). If a URL is specified, "
|
||||
"the file should be available via HTTP GET.")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single output file. Currently supports "
|
||||
"local file paths, or web (http or https) urls. If a URL is specified,"
|
||||
" the file should be available via HTTP PUT.")
|
||||
parser.add_argument(
|
||||
"--output-tmp-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory to store the output file before uploading it "
|
||||
"to the output URL.",
|
||||
)
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=True`.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument("--enable-metrics",
|
||||
action="store_true",
|
||||
help="Enable Prometheus metrics")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="URL to the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible batch runner.")
|
||||
return make_arg_parser(parser).parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
|
||||
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: Optional[tqdm] = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
self._pbar = tqdm(total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, \
|
||||
session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_local_file(output_path: str,
|
||||
batch_outputs: list[BatchRequestOutput]) -> None:
|
||||
"""
|
||||
Write the responses to a local file.
|
||||
output_path: The path to write the responses to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
"""
|
||||
# We should make this async, but as long as run_batch runs as a
|
||||
# standalone program, blocking the event loop won't effect performance.
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=f)
|
||||
|
||||
|
||||
async def upload_data(output_url: str, data_or_file: str,
|
||||
from_file: bool) -> None:
|
||||
"""
|
||||
Upload a local file to a URL.
|
||||
output_url: The URL to upload the file to.
|
||||
data_or_file: Either the data to upload or the path to the file to upload.
|
||||
from_file: If True, data_or_file is the path to the file to upload.
|
||||
"""
|
||||
# Timeout is a common issue when uploading large files.
|
||||
# We retry max_retries times before giving up.
|
||||
max_retries = 5
|
||||
# Number of seconds to wait before retrying.
|
||||
delay = 5
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# We increase the timeout to 1000 seconds to allow
|
||||
# for large files (default is 300).
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
|
||||
total=1000)) as session:
|
||||
if from_file:
|
||||
with open(data_or_file, "rb") as file:
|
||||
async with session.put(output_url,
|
||||
data=file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload file.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
else:
|
||||
async with session.put(output_url,
|
||||
data=data_or_file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload data.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.error(
|
||||
f"Failed to upload data (attempt {attempt}). "
|
||||
f"Error message: {str(e)}.\nRetrying in {delay} seconds..."
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
raise Exception(f"Failed to upload data (attempt {attempt}). "
|
||||
f"Error message: {str(e)}.") from e
|
||||
|
||||
|
||||
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
|
||||
output_tmp_dir: str) -> None:
|
||||
"""
|
||||
Write batch_outputs to a file or upload to a URL.
|
||||
path_or_url: The path or URL to write batch_outputs to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
output_tmp_dir: The directory to store the output file before uploading it
|
||||
to the output URL.
|
||||
"""
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
if output_tmp_dir is None:
|
||||
logger.info("Writing outputs to memory buffer")
|
||||
output_buffer = StringIO()
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=output_buffer)
|
||||
output_buffer.seek(0)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(
|
||||
path_or_url,
|
||||
output_buffer.read().strip().encode("utf-8"),
|
||||
from_file=False,
|
||||
)
|
||||
else:
|
||||
# Write responses to a temporary file and then upload it to the URL.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=output_tmp_dir,
|
||||
prefix="tmp_batch_output_",
|
||||
suffix=".jsonl",
|
||||
) as f:
|
||||
logger.info("Writing outputs to temporary local file %s",
|
||||
f.name)
|
||||
await write_local_file(f.name, batch_outputs)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(path_or_url, f.name, from_file=True)
|
||||
else:
|
||||
logger.info("Writing outputs to local file %s", path_or_url)
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
def make_error_request_output(request: BatchRequestInput,
|
||||
error_msg: str) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
|
||||
RerankResponse),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode")
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
async def main(args):
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
|
||||
|
||||
model_config = await engine.get_model_config()
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.runner_type == "generate" else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if model_config.task == "embed" else None
|
||||
openai_serving_scores = (ServingScores(
|
||||
engine,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if model_config.task == "score" else None)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: list[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
chat_handler_fn = openai_serving_chat.create_chat_completion if \
|
||||
openai_serving_chat is not None else None
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=
|
||||
"The model does not support Chat Completions API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
embed_handler_fn = openai_serving_embedding.create_embedding if \
|
||||
openai_serving_embedding is not None else None
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/score"):
|
||||
score_handler_fn = openai_serving_scores.create_score if \
|
||||
openai_serving_scores is not None else None
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/rerank"):
|
||||
rerank_handler_fn = openai_serving_scores.do_rerank if \
|
||||
openai_serving_scores is not None else None
|
||||
if rerank_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Rerank API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(rerank_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /score, /rerank ."
|
||||
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||
"score/rerank versions.",
|
||||
))
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
|
||||
1221
entrypoints/openai/serving_chat.py
Normal file
1221
entrypoints/openai/serving_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
160
entrypoints/openai/serving_classification.py
Normal file
160
entrypoints/openai/serving_classification.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ClassificationMixin(OpenAIServing):
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
if isinstance(ctx.request.input, str) and not ctx.request.input:
|
||||
return self.create_error_response(
|
||||
"Input cannot be empty for classification",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
(
|
||||
ctx.lora_request,
|
||||
ctx.prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer(
|
||||
ctx.lora_request)
|
||||
|
||||
if ctx.prompt_adapter_request is not None:
|
||||
raise NotImplementedError(
|
||||
"Prompt adapter is not supported for classification models"
|
||||
)
|
||||
|
||||
(
|
||||
ctx.request_prompts,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
ctx.tokenizer,
|
||||
ctx.request.input,
|
||||
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
probs = classify_res.probs
|
||||
predicted_index = int(np.argmax(probs))
|
||||
label = getattr(self.model_config.hf_config, "id2label",
|
||||
{}).get(predicted_index)
|
||||
|
||||
item = ClassificationData(
|
||||
index=idx,
|
||||
label=label,
|
||||
probs=probs,
|
||||
num_classes=len(probs),
|
||||
)
|
||||
|
||||
items.append(item)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ClassificationResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class ServingClassification(ClassificationMixin):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = (f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request)}")
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
592
entrypoints/openai/serving_completion.py
Normal file
592
entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,592 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs,
|
||||
is_text_tokens_prompt)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info("Using default completion sampling params from %s: %s",
|
||||
source, self.default_sampling_params)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"Echo is unsupported with prompt embeds.")
|
||||
|
||||
request_id = f"cmpl-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
request_prompts, engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
# Mypy does not infer that engine_prompt will have only one of
|
||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
||||
# these as Union[object, the expected type], where it infers
|
||||
# object if engine_prompt is a subclass of one of the
|
||||
# typeddicts that defines both keys. Worse, because of
|
||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
||||
# infer the type of engine_prompt correctly because of the
|
||||
# enumerate. So we need an unnecessary cast here.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if is_embeds_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_embeds"])
|
||||
elif is_tokens_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_token_ids"])
|
||||
else:
|
||||
assert_never(engine_prompt)
|
||||
default_max_tokens = self.max_model_len - input_length
|
||||
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
request_prompt = request_prompts[i]
|
||||
if is_text_tokens_prompt(request_prompt):
|
||||
final_res.prompt = request_prompt["prompt"]
|
||||
else:
|
||||
final_res.prompt = None
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
|
||||
stream_options = request.stream_options
|
||||
if stream_options:
|
||||
include_usage = stream_options.include_usage
|
||||
include_continuous_usage = include_usage and \
|
||||
stream_options.continuous_usage_stats
|
||||
else:
|
||||
include_usage, include_continuous_usage = False, False
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
prompt_text = res.prompt
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids, *output.token_ids
|
||||
]
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
if not delta_text and not delta_token_ids \
|
||||
and not previous_num_tokens[i]:
|
||||
# Chunked prefill case, don't return empty chunks
|
||||
continue
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
return_as_token_id=request.
|
||||
return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
])
|
||||
if include_continuous_usage:
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens)
|
||||
|
||||
if include_usage:
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=final_usage_info,
|
||||
)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = final_usage_info
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: list[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
|
||||
for final_res in final_res_batch:
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
else:
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
kv_transfer_params=final_res_batch[0].kv_transfer_params)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: Optional[bool] = None,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: list[int] = []
|
||||
out_token_logprobs: list[Optional[float]] = []
|
||||
out_tokens: list[str] = []
|
||||
out_top_logprobs: list[Optional[dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
should_return_as_token_id = return_as_token_id if \
|
||||
return_as_token_id is not None else self.return_tokens_as_token_ids
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
201
entrypoints/openai/serving_embedding.py
Normal file
201
entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.embedding
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
|
||||
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class EmbeddingMixin(OpenAIServing):
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
(
|
||||
ctx.lora_request,
|
||||
ctx.prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
|
||||
)
|
||||
|
||||
if ctx.prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
ctx.request_prompts,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template
|
||||
or ctx.chat_template,
|
||||
chat_template_content_format=ctx.
|
||||
chat_template_content_format,
|
||||
# In embedding requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
(ctx.request_prompts,
|
||||
ctx.engine_prompts) = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.input,
|
||||
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
items: list[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
embedding_res = EmbeddingRequestOutput.from_base(final_res)
|
||||
|
||||
item = EmbeddingResponseData(
|
||||
index=idx,
|
||||
embedding=_get_embedding(embedding_res.outputs,
|
||||
ctx.request.encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
request_id_prefix = "embd"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = (f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request)}")
|
||||
|
||||
ctx = EmbeddingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _validate_request(
|
||||
self,
|
||||
ctx: ServeContext[EmbeddingRequest],
|
||||
) -> Optional[ErrorResponse]:
|
||||
if error := super()._validate_request(ctx):
|
||||
return error
|
||||
|
||||
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
|
||||
|
||||
pooling_params = ctx.request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify(self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return None
|
||||
986
entrypoints/openai/serving_engine.py
Normal file
986
entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,986 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
|
||||
Sequence)
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
|
||||
TypeVar, Union, cast, overload)
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
PoolingResponse, RerankRequest,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
# yapf: enable
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
|
||||
MultiModalDataDict)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob, PromptLogprobs
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
|
||||
random_uuid)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingCompletionRequest, RerankRequest,
|
||||
ClassificationRequest, ScoreRequest,
|
||||
TokenizeCompletionRequest]
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
|
||||
TranscriptionRequest]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
TranscriptionResponse,
|
||||
TokenizeResponse,
|
||||
PoolingResponse,
|
||||
ClassificationResponse,
|
||||
ScoreResponse,
|
||||
]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
prompt: str
|
||||
prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class EmbedsPrompt(TypedDict):
|
||||
prompt_embeds: torch.Tensor
|
||||
|
||||
|
||||
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
|
||||
|
||||
|
||||
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
|
||||
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
|
||||
and "prompt_embeds" in prompt)
|
||||
|
||||
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
class RequestProcessingMixin(BaseModel):
|
||||
"""
|
||||
Mixin for request processing,
|
||||
handling prompt preparation and engine input.
|
||||
"""
|
||||
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
||||
engine_prompts: Optional[Union[list[EngineTokensPrompt],
|
||||
list[EngineEmbedsPrompt]]] = []
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ResponseGenerationMixin(BaseModel):
|
||||
"""
|
||||
Mixin for response generation,
|
||||
managing result generators and final batch results.
|
||||
"""
|
||||
result_generator: Optional[AsyncGenerator[tuple[int, Union[
|
||||
RequestOutput, PoolingRequestOutput]], None]] = None
|
||||
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
|
||||
default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
|
||||
Generic[RequestT]):
|
||||
# Shared across all requests
|
||||
request: RequestT
|
||||
raw_request: Optional[Request] = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = Field(default_factory=lambda: int(time.time()))
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: Optional[AnyTokenizer] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
|
||||
# `protected_namespaces` resolves Pydantic v2's warning
|
||||
# on conflict with protected namespace "model_"
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
ClassificationServeContext = ServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
|
||||
chat_template: Optional[str] = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption
|
||||
|
||||
|
||||
# Used to resolve the Pydantic error related to
|
||||
# forward reference of MultiModalDataDict in TokensPrompt
|
||||
RequestProcessingMixin.model_rebuild()
|
||||
ServeContext.model_rebuild()
|
||||
ClassificationServeContext.model_rebuild()
|
||||
EmbeddingServeContext.model_rebuild()
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
so you can easily tell “this ID came from Embedding vs Classification.”
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._tokenize_prompt_input_async = make_async(
|
||||
self._tokenize_prompt_input, executor=self._tokenizer_executor)
|
||||
self._tokenize_prompt_input_or_inputs_async = make_async(
|
||||
self._tokenize_prompt_input_or_inputs,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Default preprocessing hook. Subclasses may override
|
||||
to prepare `ctx` (classification, embedding, etc.).
|
||||
"""
|
||||
return None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
"""
|
||||
Default response builder. Subclass may override this method
|
||||
to return the appropriate response object.
|
||||
"""
|
||||
return self.create_error_response("unimplemented endpoint")
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
|
||||
generation = self._pipeline(ctx)
|
||||
|
||||
async for response in generation:
|
||||
return response
|
||||
|
||||
return self.create_error_response("No response yielded from pipeline")
|
||||
|
||||
async def _pipeline(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
|
||||
"""Execute the request processing pipeline yielding responses."""
|
||||
if error := await self._check_model(ctx.request):
|
||||
yield error
|
||||
if error := self._validate_request(ctx):
|
||||
yield error
|
||||
|
||||
preprocess_ret = await self._preprocess(ctx)
|
||||
if isinstance(preprocess_ret, ErrorResponse):
|
||||
yield preprocess_ret
|
||||
|
||||
generators_ret = await self._prepare_generators(ctx)
|
||||
if isinstance(generators_ret, ErrorResponse):
|
||||
yield generators_ret
|
||||
|
||||
collect_ret = await self._collect_batch(ctx)
|
||||
if isinstance(collect_ret, ErrorResponse):
|
||||
yield collect_ret
|
||||
|
||||
yield self._build_response(ctx)
|
||||
|
||||
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
|
||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens <= self.max_model_len:
|
||||
ctx.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
else:
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
return None
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Schedule the request and get the result generator."""
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
if not hasattr(ctx.request, "to_pooling_params"):
|
||||
return self.create_error_response(
|
||||
"Request type does not support pooling parameters")
|
||||
|
||||
pooling_params = ctx.request.to_pooling_params()
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
ctx.request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
prompt_adapter_request=ctx.prompt_adapter_request)
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance of
|
||||
# TypedDicts with `builtins.enumerate`:
|
||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||
engine_prompt = cast(
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect batch results from the result generator."""
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
final_res_batch: list[Optional[Union[RequestOutput,
|
||||
PoolingRequestOutput]]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
async for i, res in ctx.result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
if None in final_res_batch:
|
||||
return self.create_error_response(
|
||||
"Failed to generate results for all prompts")
|
||||
|
||||
ctx.final_res_batch = [
|
||||
res for res in final_res_batch if res is not None
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
|
||||
def create_streaming_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
||||
json_str = json.dumps({
|
||||
"error":
|
||||
self.create_error_response(message=message,
|
||||
err_type=err_type,
|
||||
status_code=status_code).model_dump()
|
||||
})
|
||||
return json_str
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
|
||||
error_response = None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [
|
||||
lora.lora_name for lora in self.models.lora_requests
|
||||
]:
|
||||
return None
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
|
||||
load_result := await self.models.resolve_lora(request.model)):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if isinstance(load_result, ErrorResponse) and \
|
||||
load_result.code == HTTPStatus.BAD_REQUEST.value:
|
||||
error_response = load_result
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.models.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self, request: AnyRequest
|
||||
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
|
||||
None, PromptAdapterRequest]]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
for lora in self.models.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
return lora, None
|
||||
for prompt_adapter in self.models.prompt_adapter_requests:
|
||||
if request.model == prompt_adapter.prompt_adapter_name:
|
||||
return None, prompt_adapter
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: str,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
if (self.model_config.encoder_config is not None
|
||||
and self.model_config.encoder_config.get(
|
||||
"do_lower_case", False)):
|
||||
prompt = prompt.lower()
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
|
||||
elif truncate_prompt_tokens < 0:
|
||||
# Negative means we cap at the model's max length
|
||||
encoded = tokenizer(prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=self.max_model_len)
|
||||
else:
|
||||
encoded = tokenizer(prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens)
|
||||
|
||||
input_ids = encoded.input_ids
|
||||
|
||||
input_text = prompt
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _normalize_prompt_tokens_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_ids: list[int],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
||||
) -> TextTokensPrompt:
|
||||
if truncate_prompt_tokens is None:
|
||||
input_ids = prompt_ids
|
||||
elif truncate_prompt_tokens < 0:
|
||||
input_ids = prompt_ids[-self.max_model_len:]
|
||||
else:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
|
||||
input_text = tokenizer.decode(input_ids)
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
||||
ScoreRequest, RerankRequest, ClassificationRequest)):
|
||||
|
||||
if token_num > self.max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreRequest: "score",
|
||||
ClassificationRequest: "classification"
|
||||
}
|
||||
operation = operations.get(type(request),
|
||||
"embedding generation")
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input.")
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
|
||||
DetokenizeRequest)):
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# chat completion endpoint supports max_completion_tokens
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
if max_tokens is None:
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the messages, "
|
||||
f"Please reduce the length of the messages.")
|
||||
elif token_num + max_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.")
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
def _tokenize_prompt_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_input: Union[str, list[int]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
"""
|
||||
A simpler implementation of
|
||||
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
|
||||
that assumes single input.
|
||||
"""
|
||||
return next(
|
||||
self._tokenize_prompt_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
[prompt_input],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
))
|
||||
|
||||
def _tokenize_prompt_inputs(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_inputs: Iterable[Union[str, list[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Iterator[TextTokensPrompt]:
|
||||
"""
|
||||
A simpler implementation of
|
||||
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
|
||||
that assumes multiple inputs.
|
||||
"""
|
||||
for text in prompt_inputs:
|
||||
if isinstance(text, str):
|
||||
yield self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=text,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=text,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def _tokenize_prompt_input_or_inputs(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
|
||||
"""
|
||||
Tokenize/detokenize depending on the input format.
|
||||
|
||||
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
|
||||
, each input can be a string or array of tokens. Note that each request
|
||||
can pass one or more inputs.
|
||||
"""
|
||||
inputs_embeds = list[EmbedsPrompt]()
|
||||
inputs_text = list[TextTokensPrompt]()
|
||||
|
||||
if (isinstance(request, CompletionRequest)
|
||||
and request.prompt_embeds is not None):
|
||||
inputs_embeds.extend(
|
||||
self._load_prompt_embeds(request.prompt_embeds,
|
||||
truncate_prompt_tokens))
|
||||
|
||||
# Empty prompts are okay as long as there are prompt embeddings
|
||||
if input_or_inputs is None or (inputs_embeds
|
||||
and input_or_inputs == ""):
|
||||
return [], inputs_embeds
|
||||
|
||||
# Although our type checking is based on mypy,
|
||||
# VSCode Pyright extension should still work properly
|
||||
# "is False" is required for Pyright to perform type narrowing
|
||||
# See: https://github.com/microsoft/pyright/issues/7672
|
||||
inputs_text.extend([
|
||||
self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens)
|
||||
if prompt_input["is_tokens"] is False else
|
||||
self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens)
|
||||
for prompt_input in parse_and_batch_prompt(input_or_inputs)
|
||||
])
|
||||
|
||||
return inputs_text, inputs_embeds
|
||||
|
||||
@overload
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
|
||||
RerankRequest, ClassificationRequest, ScoreRequest,
|
||||
TokenizeCompletionRequest],
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
||||
add_special_tokens: bool = ...,
|
||||
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
||||
add_special_tokens: bool = ...,
|
||||
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
|
||||
EngineTokensPrompt, EngineEmbedsPrompt]]]:
|
||||
...
|
||||
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> tuple[Union[list[TextTokensPrompt], list[Union[
|
||||
TextTokensPrompt, EmbedsPrompt]]], Union[
|
||||
list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
|
||||
EngineEmbedsPrompt]]]]:
|
||||
if not isinstance(request,
|
||||
CompletionRequest) and input_or_inputs is None:
|
||||
raise ValueError(
|
||||
"Prompt embeds with non-completion requests is not"
|
||||
" currently supported.")
|
||||
|
||||
(request_prompts_text, request_prompts_embeds
|
||||
) = await self._tokenize_prompt_input_or_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
input_or_inputs,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
engine_prompts_text = [
|
||||
EngineTokensPrompt(
|
||||
prompt_token_ids=request_prompt_text["prompt_token_ids"])
|
||||
for request_prompt_text in request_prompts_text
|
||||
]
|
||||
|
||||
# This check is equivalent to simply checking if
|
||||
# `request_prompts_embeds` is empty, but it's difficult to propagate
|
||||
# overloads to the private helper functions to enable this check.
|
||||
# This overload is needed because only TextPrompts are allowed for
|
||||
# non-completion requests and if we don't add the overload here,
|
||||
# everywhere this function is used outside of serving_completion will
|
||||
# need logic asserting that only text prompts are in the request.
|
||||
if not isinstance(request,
|
||||
CompletionRequest) and input_or_inputs is not None:
|
||||
return request_prompts_text, engine_prompts_text
|
||||
|
||||
engine_prompts_embeds = [
|
||||
EngineEmbedsPrompt(
|
||||
prompt_embeds=request_prompt_embeds["prompt_embeds"])
|
||||
for request_prompt_embeds in request_prompts_embeds
|
||||
]
|
||||
|
||||
request_prompts = request_prompts_embeds + request_prompts_text
|
||||
engine_prompts = engine_prompts_embeds + engine_prompts_text
|
||||
return request_prompts, engine_prompts
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: ChatLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tool_dicts: Optional[list[dict[str, Any]]] = None,
|
||||
documents: Optional[list[dict[str, str]]] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
|
||||
list[EngineTokensPrompt]]:
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tool_dicts,
|
||||
chat_template_content_format,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
messages,
|
||||
model_config,
|
||||
tokenizer,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
_chat_template_kwargs: dict[str, Any] = dict(
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
)
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
|
||||
request_prompt: Union[str, list[int]]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
conversation=conversation,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
|
||||
should_parse_tools = tool_parser is not None and (hasattr(
|
||||
request, "tool_choice") and request.tool_choice != "none")
|
||||
|
||||
if should_parse_tools:
|
||||
if not isinstance(request, ChatCompletionRequest):
|
||||
msg = "Tool usage is only supported for Chat Completions API"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
request = tool_parser(tokenizer).adjust_request( # type: ignore
|
||||
request=request)
|
||||
|
||||
if isinstance(request_prompt, str):
|
||||
prompt_inputs = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For MistralTokenizer
|
||||
assert is_list_of(request_prompt, int), (
|
||||
"Prompt has to be either a string or a list of token ids")
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(request_prompt),
|
||||
prompt_token_ids=request_prompt)
|
||||
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
|
||||
return conversation, [request_prompt], [engine_prompt]
|
||||
|
||||
def _load_prompt_embeds(
|
||||
self,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
) -> list[EmbedsPrompt]:
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
|
||||
weights_only=True)
|
||||
assert isinstance(
|
||||
tensor,
|
||||
(torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
if truncate_prompt_tokens is not None:
|
||||
tensor = tensor[-truncate_prompt_tokens:]
|
||||
return {"prompt_embeds": tensor}
|
||||
|
||||
if prompt_embeds:
|
||||
if isinstance(prompt_embeds, list):
|
||||
return [
|
||||
_load_and_validate_embed(embed) for embed in prompt_embeds
|
||||
]
|
||||
else:
|
||||
return [_load_and_validate_embed(prompt_embeds)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: RequestPrompt,
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
prompt, prompt_token_ids, prompt_embeds = None, None, None
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
elif isinstance(inputs, list):
|
||||
prompt_token_ids = inputs
|
||||
elif 'prompt_embeds' in inputs:
|
||||
prompt_embeds = inputs.get("prompt_embeds")
|
||||
else:
|
||||
prompt = inputs["prompt"]
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Optional[Mapping[str, str]]:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(raw_request: Optional[Request],
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
default = default or random_uuid()
|
||||
if raw_request is None:
|
||||
return default
|
||||
|
||||
return raw_request.headers.get("X-Request-Id", default)
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
return_as_token_id: bool = False) -> str:
|
||||
if return_as_token_id:
|
||||
return f"token_id:{token_id}"
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name: Optional[str]) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
def _get_model_name(self,
|
||||
model_name: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> str:
|
||||
if lora_request:
|
||||
return lora_request.lora_name
|
||||
if not model_name:
|
||||
return self.models.base_model_paths[0].name
|
||||
return model_name
|
||||
|
||||
|
||||
def clamp_prompt_logprobs(
|
||||
prompt_logprobs: Union[PromptLogprobs,
|
||||
None]) -> Union[PromptLogprobs, None]:
|
||||
if prompt_logprobs is None:
|
||||
return prompt_logprobs
|
||||
|
||||
for logprob_dict in prompt_logprobs:
|
||||
if logprob_dict is None:
|
||||
continue
|
||||
for logprob_values in logprob_dict.values():
|
||||
if logprob_values.logprob == float('-inf'):
|
||||
logprob_values.logprob = -9999.0
|
||||
return prompt_logprobs
|
||||
315
entrypoints/openai/serving_models.py
Normal file
315
entrypoints/openai/serving_models.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None,
|
||||
prompt_adapters: Optional[list[PromptAdapterPath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: list[LoRARequest] = []
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
|
||||
):
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with pathlib.Path(prompt_adapter.local_path,
|
||||
"adapter_config.json").open() as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
if self.static_lora_modules is None:
|
||||
return
|
||||
for lora in self.static_lora_modules:
|
||||
load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
|
||||
lora_name=lora.name)
|
||||
load_result = await self.load_lora_adapter(
|
||||
request=load_request, base_model_name=lora.base_model_name)
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoRAAdapterRequest,
|
||||
base_model_name: Optional[str] = None
|
||||
) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name, lora_path = request.lora_name, request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
lora_request = LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path)
|
||||
if base_model_name is not None and self.is_base_model(base_model_name):
|
||||
lora_request.base_model_name = base_model_name
|
||||
|
||||
# Validate that the adapter can be loaded into the engine
|
||||
# This will also pre-load it for incoming requests
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
except BaseException as e:
|
||||
error_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "No adapter found" in str(e):
|
||||
error_type = "NotFoundError"
|
||||
status_code = HTTPStatus.NOT_FOUND
|
||||
|
||||
return create_error_response(message=str(e),
|
||||
err_type=error_type,
|
||||
status_code=status_code)
|
||||
|
||||
self.lora_requests.append(lora_request)
|
||||
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
|
||||
lora_path)
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name = request.lora_name
|
||||
self.lora_requests = [
|
||||
lora_request for lora_request in self.lora_requests
|
||||
if lora_request.lora_name != lora_name
|
||||
]
|
||||
logger.info("Removed LoRA adapter: name '%s'", lora_name)
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been "
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if either 'lora_name' or 'lora_int_id' is provided
|
||||
if not request.lora_name and not request.lora_int_id:
|
||||
return create_error_response(
|
||||
message=
|
||||
"either 'lora_name' and 'lora_int_id' needs to be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if not any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(
|
||||
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
for existing in self.lora_requests:
|
||||
if existing.lora_name == lora_name:
|
||||
return existing
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(
|
||||
base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests.append(lora_request)
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name, resolver.__class__.__name__)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.", lora_name,
|
||||
resolver.__class__.__name__, e)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(f"LoRA adapter '{lora_name}' was found "
|
||||
"but could not be loaded."),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
232
entrypoints/openai/serving_pooling.py
Normal file
232
entrypoints/openai/serving_pooling.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_data(
|
||||
output: PoolingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.data.tolist()
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
|
||||
return base64.b64encode(pooling_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for pooling models")
|
||||
|
||||
if isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
# In pooling requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_pooling_response(
|
||||
final_res_batch_checked,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_pooling_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> PoolingResponse:
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=_get_data(final_res.outputs, encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return PoolingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
433
entrypoints/openai/serving_score.py
Normal file
433
entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,433 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ScoreResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.score_utils import (_cosine_similarity,
|
||||
_validate_score_input_lens)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id=str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
|
||||
input_texts = texts_1 + texts_2
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
|
||||
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
|
||||
text_token_prompt = \
|
||||
self._validate_input(
|
||||
request,
|
||||
tok_result["input_ids"],
|
||||
input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[Optional[PoolingRequestOutput]] =\
|
||||
[None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_texts_1: list[PoolingRequestOutput] = []
|
||||
emb_texts_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(texts_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_1.append(emb)
|
||||
|
||||
for i in range(len(texts_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_2.append(emb)
|
||||
|
||||
if len(emb_texts_1) == 1:
|
||||
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
|
||||
embed_1=emb_texts_1,
|
||||
embed_2=emb_texts_2)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: Union[AnyTokenizer],
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id=str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
|
||||
if len(texts_1) == 1:
|
||||
texts_1 = texts_1 * len(texts_2)
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
|
||||
for t1, t2 in input_pairs))
|
||||
|
||||
for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
|
||||
|
||||
request_prompt = f"{t1}{tokenizer.sep_token}{t2}"
|
||||
|
||||
input_ids = prompt_inputs["input_ids"]
|
||||
text_token_prompt = \
|
||||
self._validate_input(request, input_ids, request_prompt)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[
|
||||
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
texts_1: Union[str, list[str]],
|
||||
texts_2: Union[str, list[str]],
|
||||
request: Union[ScoreRequest, RerankRequest],
|
||||
request_id: str,
|
||||
raw_request: Optional[Request] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if isinstance(texts_1, str):
|
||||
texts_1 = [texts_1]
|
||||
if isinstance(texts_2, str):
|
||||
texts_2 = [texts_2]
|
||||
|
||||
_validate_score_input_lens(texts_1, texts_2)
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=texts_1,
|
||||
texts_2=texts_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=texts_1,
|
||||
texts_2=texts_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[ScoreResponse, ErrorResponse]:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.text_1,
|
||||
request.text_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self._get_model_name(request.model),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
top_n = request.top_n if request.top_n > 0 else len(documents)
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self._get_model_name(request.model),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: list[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: list[str],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
157
entrypoints/openai/serving_tokenization.py
Normal file
157
entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
request: TokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[TokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
tool_dicts = (None if request.tools is None else
|
||||
[tool.model_dump() for tool in request.tools])
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
input_ids: list[int] = []
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
self._log_inputs(request_id,
|
||||
request_prompts[i],
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect
|
||||
# tokenization (Unlike in Embeddings API where an error is raised)
|
||||
if isinstance(engine_prompt,
|
||||
dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
token_strs=token_strs,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len)
|
||||
|
||||
async def create_detokenize(
|
||||
self,
|
||||
request: DetokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[DetokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.tokens,
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
# (Unlike in Embeddings API where an error is raised)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
424
entrypoints/openai/serving_transcription.py
Normal file
424
entrypoints/openai/serving_transcription.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from math import ceil
|
||||
from typing import Final, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
|
||||
# TODO these configs should live somewhere with the model so we can support
|
||||
# additional ones
|
||||
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"af": "Afrikaans",
|
||||
"ar": "Arabic",
|
||||
"hy": "Armenian",
|
||||
"az": "Azerbaijani",
|
||||
"be": "Belarusian",
|
||||
"bs": "Bosnian",
|
||||
"bg": "Bulgarian",
|
||||
"ca": "Catalan",
|
||||
"zh": "Chinese",
|
||||
"hr": "Croatian",
|
||||
"cs": "Czech",
|
||||
"da": "Danish",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"et": "Estonian",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"gl": "Galician",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"hu": "Hungarian",
|
||||
"is": "Icelandic",
|
||||
"id": "Indonesian",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"kn": "Kannada",
|
||||
"kk": "Kazakh",
|
||||
"ko": "Korean",
|
||||
"lv": "Latvian",
|
||||
"lt": "Lithuanian",
|
||||
"mk": "Macedonian",
|
||||
"ms": "Malay",
|
||||
"mr": "Marathi",
|
||||
"mi": "Maori",
|
||||
"ne": "Nepali",
|
||||
"no": "Norwegian",
|
||||
"fa": "Persian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ro": "Romanian",
|
||||
"ru": "Russian",
|
||||
"sr": "Serbian",
|
||||
"sk": "Slovak",
|
||||
"sl": "Slovenian",
|
||||
"es": "Spanish",
|
||||
"sw": "Swahili",
|
||||
"sv": "Swedish",
|
||||
"tl": "Tagalog",
|
||||
"ta": "Tamil",
|
||||
"th": "Thai",
|
||||
"tr": "Turkish",
|
||||
"uk": "Ukrainian",
|
||||
"ur": "Urdu",
|
||||
"vi": "Vietnamese",
|
||||
"cy": "Welsh"
|
||||
}
|
||||
ISO639_1_OTHER_LANGS = {
|
||||
"lo": "Lao",
|
||||
"jw": "Javanese",
|
||||
"tk": "Turkmen",
|
||||
"yi": "Yiddish",
|
||||
"so": "Somali",
|
||||
"bn": "Bengali",
|
||||
"nn": "Norwegian Nynorsk",
|
||||
"si": "Sinhala",
|
||||
"yo": "Yoruba",
|
||||
"sa": "Sanskrit",
|
||||
"mi": "Māori",
|
||||
"fo": "Faroese", # codespell:ignore
|
||||
"mt": "Maltese",
|
||||
"tg": "Tajik",
|
||||
"mg": "Malagasy",
|
||||
"haw": "Hawaiian",
|
||||
"km": "Khmer",
|
||||
"br": "Breton",
|
||||
"ps": "Pashto",
|
||||
"ln": "Lingala",
|
||||
"la": "Latin",
|
||||
"ml": "Malayalam",
|
||||
"sq": "Albanian",
|
||||
"su": "Sundanese",
|
||||
"eu": "Basque",
|
||||
"ka": "Georgian",
|
||||
"uz": "Uzbek",
|
||||
"sn": "Shona",
|
||||
"ht": "Haitian",
|
||||
"as": "Assamese",
|
||||
"mn": "Mongolian",
|
||||
"te": "Telugu",
|
||||
"pa": "Panjabi",
|
||||
"tt": "Tatar",
|
||||
"gu": "Gujarati",
|
||||
"oc": "Occitan",
|
||||
"ha": "Hausa",
|
||||
"ba": "Bashkir",
|
||||
"my": "Burmese",
|
||||
"sd": "Sindhi",
|
||||
"am": "Amharic",
|
||||
"lb": "Luxembourgish",
|
||||
"bo": "Tibetan"
|
||||
}
|
||||
|
||||
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||
# TODO configurable
|
||||
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
processor = cached_get_processor(model_config.model)
|
||||
self.max_audio_clip_s = processor.feature_extractor.chunk_length
|
||||
self.model_sr = processor.feature_extractor.sampling_rate
|
||||
self.hop_length = processor.feature_extractor.hop_length
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
async def _preprocess_transcription(
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[PromptType, float]:
|
||||
# Validate request
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
|
||||
if request.language:
|
||||
if request.language in ISO639_1_SUPPORTED_LANGS:
|
||||
pass
|
||||
elif request.language in ISO639_1_OTHER_LANGS:
|
||||
logger.warning(
|
||||
"The selected language %s has limited accuracy with"
|
||||
" reported WER>=0.5. Results may be less accurate "
|
||||
"for this choice.", request.language)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported language: {request.language}."
|
||||
"Language should be one of:" +
|
||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
y, sr = librosa.load(bytes_)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
if duration > self.max_audio_clip_s:
|
||||
raise ValueError(
|
||||
f"Maximum clip duration ({self.max_audio_clip_s}s) "
|
||||
"exceeded.")
|
||||
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (y, sr),
|
||||
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||
}
|
||||
return cast(PromptType, prompt), duration
|
||||
|
||||
# TODO (varun) : Make verbose response work !
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
|
||||
ErrorResponse]:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ['text', 'json']:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"trsc-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for Transcription.")
|
||||
if prompt_adapter_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support PromptAdapter for Transcription."
|
||||
)
|
||||
|
||||
prompt, duration_s = await self._preprocess_transcription(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
prompt['decoder_prompt'], # type: ignore
|
||||
params=sampling_params,
|
||||
lora_request=None,
|
||||
prompt_adapter_request=None)
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return self.transcription_stream_generator(request,
|
||||
result_generator,
|
||||
request_id,
|
||||
request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert result_generator is not None
|
||||
async for op in result_generator:
|
||||
result = op
|
||||
return TranscriptionResponse(text=result.outputs[0].text)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
chunk_object_type: Final = "transcription.chunk"
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
# Do not account the 4-tokens `<|startoftranscript|>..`
|
||||
# Could be negative when language token is not specified.
|
||||
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
|
||||
# NOTE(NickLucche) user can't pass encoder prompts directly
|
||||
# at least not to Whisper. One indicator of the encoder
|
||||
# amount of processing is the log-mel spectogram length.
|
||||
num_prompt_tokens += ceil(audio_duration_s *
|
||||
self.model_sr / self.hop_length)
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = TranscriptionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = TranscriptionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in chat completion stream generator.")
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
23
entrypoints/openai/tool_parsers/__init__.py
Normal file
23
entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .jamba_tool_parser import JambaToolParser
|
||||
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
|
||||
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
|
||||
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser"
|
||||
]
|
||||
164
entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
164
entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from functools import cached_property
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import import_from_path, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""
|
||||
Abstract ToolParser class that should not be used directly. Provided
|
||||
properties and methods should be used in
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
# whereas all tokenizers have .get_vocab()
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
Used for non-streaming responses where we have the entire model response
|
||||
available before sending to the client.
|
||||
Static because it's stateless.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls has not been implemented!")
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
from an incomplete response; for use when handling tool calls and
|
||||
streaming. Has to be an instance method because it requires state -
|
||||
the current tokens/diffs, but also the information about what has
|
||||
previously been parsed and extracted (see constructor)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
|
||||
tool_parsers: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name) -> type:
|
||||
"""
|
||||
Get tool parser by name which is registered by `register_module`.
|
||||
|
||||
Raise a KeyError exception if the name is not registered.
|
||||
"""
|
||||
if name in cls.tool_parsers:
|
||||
return cls.tool_parsers[name]
|
||||
|
||||
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||
|
||||
@classmethod
|
||||
def _register_module(cls,
|
||||
module: type,
|
||||
module_name: Optional[Union[str, list[str]]] = None,
|
||||
force: bool = True) -> None:
|
||||
if not issubclass(module, ToolParser):
|
||||
raise TypeError(
|
||||
f'module must be subclass of ToolParser, but got {type(module)}'
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in cls.tool_parsers:
|
||||
existed_module = cls.tool_parsers[name]
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'at {existed_module.__module__}')
|
||||
cls.tool_parsers[name] = module
|
||||
|
||||
@classmethod
|
||||
def register_module(
|
||||
cls,
|
||||
name: Optional[Union[str, list[str]]] = None,
|
||||
force: bool = True,
|
||||
module: Union[type, None] = None) -> Union[type, Callable]:
|
||||
"""
|
||||
Register module with the given name or name list. it can be used as a
|
||||
decoder(with module as None) or normal function(with module as not
|
||||
None).
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str)
|
||||
or is_list_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be None, an instance of str, or a sequence of str, '
|
||||
f'but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||
"""
|
||||
Import a user-defined tool parser by the path of the tool parser define
|
||||
file.
|
||||
"""
|
||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||
|
||||
try:
|
||||
import_from_path(module_name, plugin_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to load module '%s' from %s.",
|
||||
module_name, plugin_path)
|
||||
return
|
||||
370
entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
Normal file
370
entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("deepseek_v3")
|
||||
class DeepSeekV3ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool▁call▁begin|>"
|
||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
|
||||
)
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"DeepSeek-V3 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
tool_type, function_name, function_args = match
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type=tool_type,
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_type, tool_name, tool_args = (
|
||||
current_tool_call_matches.groups())
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_type, tool_name = (
|
||||
current_tool_call_name_matches.groups())
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
259
entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
Normal file
259
entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecoder
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("granite-20b-fc")
|
||||
class Granite20bFCToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for the granite-20b-functioncalling model intended
|
||||
for use with the examples/tool_chat_template_granite20b_fc.jinja
|
||||
template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.bot_token = "<function_call>"
|
||||
self.tool_start_token = self.bot_token
|
||||
self.tool_call_regex = re.compile(r"<function_call>\s*")
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
if self.tool_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
dec = JSONDecoder()
|
||||
try:
|
||||
matches = list(self.tool_call_regex.finditer(model_output))
|
||||
logger.debug("Found %d tool call matches", len(matches))
|
||||
|
||||
raw_function_calls = []
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
# position after the <function_call> tag
|
||||
start_of_json = match.end()
|
||||
# end_index == the start of the next function call
|
||||
# (if exists)
|
||||
next_function_call_start = (matches[i + 1].start() if i +
|
||||
1 < len(matches) else None)
|
||||
|
||||
raw_function_calls.append(
|
||||
dec.raw_decode(
|
||||
model_output[start_of_json:next_function_call_start])
|
||||
[0])
|
||||
|
||||
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
),
|
||||
) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.find(self.bot_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s", e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if len(current_text) < len(
|
||||
self.bot_token) and self.bot_token.startswith(current_text):
|
||||
return None
|
||||
|
||||
if not current_text.startswith(self.bot_token):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
start_idx = len(self.bot_token)
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
|
||||
while start_idx < len(current_text):
|
||||
(obj,
|
||||
end_idx) = partial_json_loads(current_text[start_idx:],
|
||||
flags)
|
||||
is_complete.append(
|
||||
is_complete_json(current_text[start_idx:start_idx +
|
||||
end_idx]))
|
||||
start_idx += end_idx
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
start_idx += len(self.bot_token)
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
tool_call_arr.append(obj)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
delta = None
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
237
entrypoints/openai/tool_parsers/granite_tool_parser.py
Normal file
237
entrypoints/openai/tool_parsers/granite_tool_parser.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("granite")
|
||||
class GraniteToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for the granite 3.0 models. Intended
|
||||
for use with the examples/tool_chat_template_granite.jinja
|
||||
template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser granite
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
# for granite 3.0, the token `<|tool_call|>`
|
||||
self.bot_token = "<|tool_call|>"
|
||||
# for granite 3.1, the string `<tool_call>`
|
||||
self.bot_string = "<tool_call>"
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
stripped = model_output.strip()\
|
||||
.removeprefix(self.bot_token)\
|
||||
.removeprefix(self.bot_string)\
|
||||
.lstrip()
|
||||
if not stripped or stripped[0] != '[':
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
try:
|
||||
raw_function_calls = json.loads(stripped)
|
||||
if not isinstance(raw_function_calls, list):
|
||||
raise Exception(
|
||||
f"Expected dict or list, got {type(raw_function_calls)}")
|
||||
|
||||
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
),
|
||||
) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s", e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
start_idx = consume_space(0, current_text)
|
||||
if current_text[start_idx:].startswith(self.bot_token):
|
||||
start_idx = consume_space(start_idx + len(self.bot_token),
|
||||
current_text)
|
||||
if current_text[start_idx:].startswith(self.bot_string):
|
||||
start_idx = consume_space(start_idx + len(self.bot_string),
|
||||
current_text)
|
||||
if not current_text or start_idx >= len(current_text)\
|
||||
or current_text[start_idx] != '[':
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = None
|
||||
is_complete = None
|
||||
try:
|
||||
tool_calls, end_idx = partial_json_loads(
|
||||
current_text[start_idx:], flags)
|
||||
if type(tool_calls) is list:
|
||||
tool_call_arr = tool_calls
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
is_complete = [True] * len(tool_calls)
|
||||
if not is_complete_json(
|
||||
current_text[start_idx:start_idx + end_idx]):
|
||||
is_complete[-1] = False
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if not tool_call_arr:
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id]
|
||||
|
||||
delta = None
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
if len(tool_call_arr) > self.current_tool_id + 1:
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
371
entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
371
entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
@@ -0,0 +1,371 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.error(
|
||||
"Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
|
||||
self.scratch_pad_regex = re.compile(
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = (
|
||||
self.tool_call_regex.findall(model_output))
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = [
|
||||
json.loads(match[0] if match[0] else match[1])
|
||||
for match in function_call_tuples
|
||||
]
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False)))
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_call_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case: if tool open & close tag counts don't match, we're doing
|
||||
# imaginary "else" block here
|
||||
# something with tools with this diff.
|
||||
# flags for partial JSON parting. exported constants from
|
||||
# "Allow" are handled via BIT MASK
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if (self.prev_tool_call_arr is None
|
||||
or len(self.prev_tool_call_arr) == 0):
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = diff.encode('utf-8').decode(
|
||||
'unicode_escape') if diff is str else diff
|
||||
if ('"}' not in delta_text):
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
try:
|
||||
|
||||
current_tool_call = partial_json_parser.loads(
|
||||
tool_call_portion or "{}",
|
||||
flags) if tool_call_portion else None
|
||||
logger.debug("Parsed tool call %s", current_tool_call)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.debug("unable to parse JSON")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if (current_tool_call is None):
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
else:
|
||||
return None
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = DeltaMessage(content=delta_text) \
|
||||
if text_portion is not None else None
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = (
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current
|
||||
if (delta_text not in cur_arguments_json[:-2]):
|
||||
return None
|
||||
args_delta_start_loc = cur_arguments_json[:-2]. \
|
||||
rindex(delta_text) + \
|
||||
len(delta_text)
|
||||
|
||||
# use that to find the actual delta
|
||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= arguments_delta
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if isinstance(delta_text, str) and len(delta_text.rstrip(
|
||||
)) >= 1 and delta_text.rstrip()[-1] == '}':
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= delta_text
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = \
|
||||
current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
216
entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
216
entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
|
||||
class Internlm2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because internlm use the special
|
||||
# tokens to indicated the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def get_argments(self, obj):
|
||||
if "parameters" in obj:
|
||||
return obj.get("parameters")
|
||||
elif "arguments" in obj:
|
||||
return obj.get("arguments")
|
||||
return None
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
if '<|action_start|>' not in current_text:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=delta_text)
|
||||
# if the tool call is sended, return a empty delta message
|
||||
# to make sure the finish_reason will be send correctly.
|
||||
if self.current_tool_id > 0:
|
||||
return DeltaMessage(content='')
|
||||
|
||||
last_pos = self.position
|
||||
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
|
||||
return None
|
||||
|
||||
new_delta = current_text[last_pos:]
|
||||
text, action = new_delta.split('<|action_start|><|plugin|>')
|
||||
|
||||
if len(text) > 0:
|
||||
self.position = self.position + len(text)
|
||||
return DeltaMessage(content=text)
|
||||
|
||||
action = action.strip()
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
try:
|
||||
parsable_arr = action
|
||||
|
||||
# tool calls are generated in an object in inernlm2
|
||||
# it's not support parallel tool calls
|
||||
try:
|
||||
tool_call_arr: dict = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = tool_call_arr.get("name")
|
||||
if function_name:
|
||||
self.current_tool_id = self.current_tool_id + 1
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
self.streamed_args_for_tool.append("")
|
||||
else:
|
||||
delta = None
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
prev_arguments = self.get_argments(
|
||||
self.prev_tool_call_arr[self.current_tool_id])
|
||||
cur_arguments = self.get_argments(tool_call_arr)
|
||||
|
||||
# not arguments generated
|
||||
if not cur_arguments and not prev_arguments:
|
||||
delta = None
|
||||
# will never happen
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
# first time to get parameters
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(delta_text) +
|
||||
len(delta_text)]
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
# both prev and cur parameters, send the increase parameters
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
|
||||
self.prev_tool_call_arr = [tool_call_arr]
|
||||
return delta
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
text = model_output
|
||||
tools = request.tools
|
||||
if '<|action_start|><|plugin|>' in text:
|
||||
text, action = text.split('<|action_start|><|plugin|>')
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
action = action[action.find('{'):]
|
||||
action_dict = json.loads(action)
|
||||
name, parameters = action_dict['name'], json.dumps(
|
||||
action_dict.get('parameters', action_dict.get('arguments',
|
||||
{})),
|
||||
ensure_ascii=False)
|
||||
|
||||
if not tools or name not in [t.function.name for t in tools]:
|
||||
ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(name=name, arguments=parameters))
|
||||
]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=text if len(text) > 0 else None)
|
||||
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
308
entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
308
entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("jamba")
|
||||
class JambaToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Detected a MistralTokenizer tokenizer when using a Jamba model"
|
||||
)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<tool_calls>"
|
||||
self.tool_calls_end_token: str = "</tool_calls>"
|
||||
|
||||
self.tool_calls_regex = re.compile(
|
||||
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
|
||||
re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Jamba Tool parser could not locate tool calls start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because jamba use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# use a regex to find the tool call between the tags
|
||||
function_calls = self.tool_calls_regex.findall(model_output)[0]
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = json.loads(function_calls)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
)) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if
|
||||
(len(content) > 0 and content != " ") else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.tool_calls_start_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the start of tool calls token which means
|
||||
# the start of tool calling
|
||||
if (self.tool_calls_start_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion and don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# Extract the tool calls between the special tool call tokens
|
||||
parsable_arr = current_text.split(
|
||||
self.tool_calls_start_token)[-1].split(
|
||||
self.tool_calls_end_token)[0]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: list[dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff, ensure_ascii=False).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
316
entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
Normal file
316
entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama4_pythonic")
|
||||
class Llama4PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Toolcall parser for Llama4 that produce tool calls in a pythonic style
|
||||
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
|
||||
"""
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
|
||||
# remove <|python_start|> and <|python_end|>
|
||||
# as Llama 4 model sometime will output those tokens
|
||||
if model_output.startswith("<|python_start|>"):
|
||||
model_output = model_output[len("<|python_start|>"):]
|
||||
model_output = model_output.replace("<|python_end|>", "")
|
||||
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not current_text.startswith("[") and not current_text.startswith(
|
||||
"<|python_start|>"):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
# remove <|python_start|> and <|python_end|>
|
||||
if current_text.startswith("<|python_start|>"):
|
||||
current_text = current_text[len("<|python_start|>"):]
|
||||
if current_text.endswith("<|python_end|>"):
|
||||
current_text = current_text[:current_text.
|
||||
rfind("<|python_end|>")]
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(
|
||||
tool_calls) - 1 or ")]" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = (added_text[:-2]
|
||||
if not new_call_complete else "")
|
||||
if not new_call_complete and added_text[-2] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
|
||||
new_call, index, withheld_suffix)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (delta.function is not None
|
||||
and delta.function.arguments is not None):
|
||||
self.streamed_args_for_tool[
|
||||
index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining it's final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content='')
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError(
|
||||
"Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(arguments)))
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[:text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[:text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
|
||||
"[") and not text.endswith(")"):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
index: int,
|
||||
withheld_suffix: str) -> Union[DeltaToolCall, None]:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[:-len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
))
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args):]
|
||||
return DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(
|
||||
arguments=arg_diff)) if arg_diff else None
|
||||
267
entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
267
entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecoder
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
|
||||
@ToolParserManager.register_module("llama4_json")
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.1 models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser llama3_json
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "<|python_tag|>"
|
||||
self.bot_token_id = tokenizer.encode(self.bot_token,
|
||||
add_special_tokens=False)[0]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if not (model_output.startswith(self.bot_token)
|
||||
or model_output.startswith('{')):
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
dec = JSONDecoder()
|
||||
function_call_arr = []
|
||||
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = len(self.bot_token) if model_output.startswith(
|
||||
self.bot_token) else 0
|
||||
while start_idx < len(model_output):
|
||||
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
|
||||
start_idx += end_idx + len('; ')
|
||||
function_call_arr.append(obj)
|
||||
|
||||
tool_calls: list[ToolCall] = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"] \
|
||||
if "arguments" in raw_function_call \
|
||||
else raw_function_call["parameters"],
|
||||
ensure_ascii=False)))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
ret = ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None)
|
||||
return ret
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not (current_text.startswith(self.bot_token)
|
||||
or current_text.startswith('{')):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = len(self.bot_token) if current_text.startswith(
|
||||
self.bot_token) else 0
|
||||
while start_idx < len(current_text):
|
||||
(obj,
|
||||
end_idx) = partial_json_loads(current_text[start_idx:],
|
||||
flags)
|
||||
is_complete.append(
|
||||
is_complete_json(current_text[start_idx:start_idx +
|
||||
end_idx]))
|
||||
start_idx += end_idx + len('; ')
|
||||
# depending on the prompt Llama can use
|
||||
# either arguments or parameters
|
||||
if "parameters" in obj:
|
||||
assert "arguments" not in obj, \
|
||||
"model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
delta = None
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
369
entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
369
entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
@@ -0,0 +1,369 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
|
||||
id: str = Field(
|
||||
default_factory=lambda: MistralToolCall.generate_random_id())
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
@staticmethod
|
||||
def is_valid_id(id: str) -> bool:
|
||||
return id.isalnum() and len(id) == 9
|
||||
|
||||
|
||||
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
|
||||
return isinstance(model_tokenizer, MistralTokenizer) \
|
||||
and model_tokenizer.version >= 11
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
|
||||
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
|
||||
- the examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
|
||||
re.DOTALL)
|
||||
else:
|
||||
self.fn_name_regex = None
|
||||
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
"the tokenizer!")
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if not isinstance(
|
||||
self.model_tokenizer, MistralTokenizer
|
||||
) and request.tools and request.tool_choice != 'none':
|
||||
# Do not skip special tokens when using chat template
|
||||
# with Mistral parser as TOOL_CALL token is needed
|
||||
# for tool detection.
|
||||
# Note: we don't want skip_special_tokens=False
|
||||
# with MistralTokenizer as it is incompatible
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
make sure your tool call arguments don't ever include quotes!
|
||||
"""
|
||||
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if self.bot_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# first remove the BOT token
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
|
||||
try:
|
||||
# we first try to directly load the json as parsing very nested
|
||||
# jsons is difficult
|
||||
try:
|
||||
if self.fn_name_regex:
|
||||
matches = self.fn_name_regex.findall(tool_content)
|
||||
|
||||
function_call_arr = []
|
||||
for match in matches:
|
||||
fn_name = match[0]
|
||||
args = match[1]
|
||||
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append({
|
||||
"name": fn_name,
|
||||
"arguments": json.loads(args)
|
||||
})
|
||||
else:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
# use a regex to find the part corresponding to the tool call.
|
||||
# NOTE: This use case should not happen if the model is trained
|
||||
# correctly. It's a easy possible fix so it's included, but
|
||||
# can be brittle for very complex / highly nested tool calls
|
||||
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
|
||||
# Tool Call
|
||||
tool_calls: list[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"],
|
||||
ensure_ascii=False)))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
content = model_output.split(self.bot_token)[0]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if len(content) > 0 else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=tool_content)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.bot_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the BOT token which means the start of tool
|
||||
# calling
|
||||
if (self.bot_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion any don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# replace BOT token with empty string, and convert single quotes
|
||||
# to double to allow parsing as JSON since mistral uses single
|
||||
# quotes instead of double for tool calls
|
||||
parsable_arr = current_text.split(self.bot_token)[-1]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: list[dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff, ensure_ascii=False).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=MistralToolCall.generate_random_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
if ('"}' in new_text):
|
||||
new_text = new_text[:new_text.rindex('"}')]
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)[:-2]
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
if (new_text not in cur_arguments_json):
|
||||
return None
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
rindex(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
112
entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
Normal file
112
entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("phi4_mini_json")
|
||||
class Phi4MiniJsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for phi-4-mini models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token: str = "functools"
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
logger.debug("Model output: %s", model_output)
|
||||
|
||||
pattern = r'functools\[(.*?)\]'
|
||||
matches = re.search(pattern, model_output, re.DOTALL)
|
||||
|
||||
if not matches:
|
||||
logger.debug("No function calls found")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
function_call_arr: list[dict[str, Any]] = []
|
||||
try:
|
||||
json_content = '[' + matches.group(1) + ']'
|
||||
|
||||
function_call_arr = json.loads(json_content)
|
||||
logger.debug("Successfully extracted %d function calls",
|
||||
len(function_call_arr))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to parse function calls from model output. "
|
||||
"Error: %s", str(e))
|
||||
|
||||
tool_calls: list[ToolCall] = [
|
||||
ToolCall(
|
||||
id=random_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
raw_function_call["arguments"]
|
||||
if "arguments" in raw_function_call else
|
||||
raw_function_call["parameters"],
|
||||
ensure_ascii=False),
|
||||
)) for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
ret = ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None)
|
||||
return ret
|
||||
|
||||
except Exception:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Optional[DeltaMessage]:
|
||||
|
||||
return None
|
||||
308
entrypoints/openai/tool_parsers/pythonic_tool_parser.py
Normal file
308
entrypoints/openai/tool_parsers/pythonic_tool_parser.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("pythonic")
|
||||
class PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for models that produce tool calls in a pythonic style,
|
||||
such as Llama 3.2 and Llama 4 models.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
|
||||
"""
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not current_text.startswith("["):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(
|
||||
tool_calls) - 1 or ")]" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = (added_text[:-2]
|
||||
if not new_call_complete else "")
|
||||
if not new_call_complete and added_text[-2] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
|
||||
new_call, index, withheld_suffix)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (delta.function is not None
|
||||
and delta.function.arguments is not None):
|
||||
self.streamed_args_for_tool[
|
||||
index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining it's final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content='')
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError(
|
||||
"Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(arguments,
|
||||
ensure_ascii=False)),
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[:text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[:text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
|
||||
"[") and not text.endswith(")"):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
index: int,
|
||||
withheld_suffix: str) -> Union[DeltaToolCall, None]:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[:-len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
))
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args):]
|
||||
return DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(
|
||||
arguments=arg_diff)) if arg_diff else None
|
||||
124
entrypoints/openai/tool_parsers/utils.py
Normal file
124
entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the secnod argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, '', 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def find_all_indices(string: str, substring: str) -> list[int]:
|
||||
"""
|
||||
Find all (starting) indices of a substring in a given string. Useful for
|
||||
tool call extraction
|
||||
"""
|
||||
indices = []
|
||||
index = -1
|
||||
while True:
|
||||
index = string.find(substring, index + 1)
|
||||
if index == -1:
|
||||
break
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
||||
50
entrypoints/score_utils.py
Normal file
50
entrypoints/score_utils.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
from torch.nn import CosineSimilarity
|
||||
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
|
||||
def _cosine_similarity(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
embed_1: list[PoolingRequestOutput],
|
||||
embed_2: list[PoolingRequestOutput],
|
||||
) -> list[PoolingRequestOutput]:
|
||||
|
||||
scorer = CosineSimilarity(0)
|
||||
scores: Union[list[PoolingRequestOutput]] = []
|
||||
|
||||
for emb_1, emb_2 in zip(embed_1, embed_2):
|
||||
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
|
||||
|
||||
padding = []
|
||||
if (pad_token_id := getattr(tokenizer, "pad_token_id",
|
||||
None)) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=pair_score,
|
||||
prompt_token_ids=tokens,
|
||||
finished=True))
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def _validate_score_input_lens(
|
||||
texts_1: Union[list[str], list[dict]],
|
||||
texts_2: Union[list[str], list[dict]],
|
||||
):
|
||||
if len(texts_1) > 1 and len(texts_1) != len(texts_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(texts_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(texts_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
75
entrypoints/ssl.py
Normal file
75
entrypoints/ssl.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from typing import Callable, Optional
|
||||
|
||||
from watchfiles import Change, awatch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SSLCertRefresher:
|
||||
"""A class that monitors SSL certificate files and
|
||||
reloads them when they change.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ssl_context: SSLContext,
|
||||
key_path: Optional[str] = None,
|
||||
cert_path: Optional[str] = None,
|
||||
ca_path: Optional[str] = None) -> None:
|
||||
self.ssl = ssl_context
|
||||
self.key_path = key_path
|
||||
self.cert_path = cert_path
|
||||
self.ca_path = ca_path
|
||||
|
||||
# Setup certification chain watcher
|
||||
def update_ssl_cert_chain(change: Change, file_path: str) -> None:
|
||||
logger.info("Reloading SSL certificate chain")
|
||||
assert self.key_path and self.cert_path
|
||||
self.ssl.load_cert_chain(self.cert_path, self.key_path)
|
||||
|
||||
self.watch_ssl_cert_task = None
|
||||
if self.key_path and self.cert_path:
|
||||
self.watch_ssl_cert_task = asyncio.create_task(
|
||||
self._watch_files([self.key_path, self.cert_path],
|
||||
update_ssl_cert_chain))
|
||||
|
||||
# Setup CA files watcher
|
||||
def update_ssl_ca(change: Change, file_path: str) -> None:
|
||||
logger.info("Reloading SSL CA certificates")
|
||||
assert self.ca_path
|
||||
self.ssl.load_verify_locations(self.ca_path)
|
||||
|
||||
self.watch_ssl_ca_task = None
|
||||
if self.ca_path:
|
||||
self.watch_ssl_ca_task = asyncio.create_task(
|
||||
self._watch_files([self.ca_path], update_ssl_ca))
|
||||
|
||||
async def _watch_files(self, paths, fun: Callable[[Change, str],
|
||||
None]) -> None:
|
||||
"""Watch multiple file paths asynchronously."""
|
||||
logger.info("SSLCertRefresher monitors files: %s", paths)
|
||||
async for changes in awatch(*paths):
|
||||
try:
|
||||
for change, file_path in changes:
|
||||
logger.info("File change detected: %s - %s", change.name,
|
||||
file_path)
|
||||
fun(change, file_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"SSLCertRefresher failed taking action on file change. "
|
||||
"Error: %s", e)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop watching files."""
|
||||
if self.watch_ssl_cert_task:
|
||||
self.watch_ssl_cert_task.cancel()
|
||||
self.watch_ssl_cert_task = None
|
||||
if self.watch_ssl_ca_task:
|
||||
self.watch_ssl_ca_task.cancel()
|
||||
self.watch_ssl_ca_task = None
|
||||
233
entrypoints/utils.py
Normal file
233
entrypoints/utils.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
VLLM_SUBCMD_PARSER_EPILOG = (
|
||||
"Tip: Use `vllm [serve|run-batch] --help=<keyword>` "
|
||||
"to explore arguments from help.\n"
|
||||
" - To view a argument group: --help=ModelConfig\n"
|
||||
" - To view a single argument: --help=max-num-seqs\n"
|
||||
" - To search by keyword: --help=max\n"
|
||||
" - To list all groups: --help=listgroup")
|
||||
|
||||
|
||||
async def listen_for_disconnect(request: Request) -> None:
|
||||
"""Returns if a disconnect message is received"""
|
||||
while True:
|
||||
message = await request.receive()
|
||||
if message["type"] == "http.disconnect":
|
||||
if request.app.state.enable_server_load_tracking:
|
||||
# on timeout/cancellation the BackgroundTask in load_aware_call
|
||||
# cannot decrement the server load metrics.
|
||||
# Must be decremented by with_cancellation instead.
|
||||
request.app.state.server_load_metrics -= 1
|
||||
break
|
||||
|
||||
|
||||
def with_cancellation(handler_func):
|
||||
"""Decorator that allows a route handler to be cancelled by client
|
||||
disconnections.
|
||||
|
||||
This does _not_ use request.is_disconnected, which does not work with
|
||||
middleware. Instead this follows the pattern from
|
||||
starlette.StreamingResponse, which simultaneously awaits on two tasks- one
|
||||
to wait for an http disconnect message, and the other to do the work that we
|
||||
want done. When the first task finishes, the other is cancelled.
|
||||
|
||||
A core assumption of this method is that the body of the request has already
|
||||
been read. This is a safe assumption to make for fastapi handlers that have
|
||||
already parsed the body of the request into a pydantic model for us.
|
||||
This decorator is unsafe to use elsewhere, as it will consume and throw away
|
||||
all incoming messages for the request while it looks for a disconnect
|
||||
message.
|
||||
|
||||
In the case where a `StreamingResponse` is returned by the handler, this
|
||||
wrapper will stop listening for disconnects and instead the response object
|
||||
will start listening for disconnects.
|
||||
"""
|
||||
|
||||
# Functools.wraps is required for this wrapper to appear to fastapi as a
|
||||
# normal route handler, with the correct request type hinting.
|
||||
@functools.wraps(handler_func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
|
||||
# The request is either the second positional arg or `raw_request`
|
||||
request = args[1] if len(args) > 1 else kwargs["raw_request"]
|
||||
|
||||
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
||||
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
||||
|
||||
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if handler_task in done:
|
||||
return handler_task.result()
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def decrement_server_load(request: Request):
|
||||
request.app.state.server_load_metrics -= 1
|
||||
|
||||
|
||||
def load_aware_call(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
raw_request = kwargs.get("raw_request",
|
||||
args[1] if len(args) > 1 else None)
|
||||
|
||||
if raw_request is None:
|
||||
raise ValueError(
|
||||
"raw_request required when server load tracking is enabled")
|
||||
|
||||
if not raw_request.app.state.enable_server_load_tracking:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
raw_request.app.state.server_load_metrics += 1
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
except Exception:
|
||||
raw_request.app.state.server_load_metrics -= 1
|
||||
raise
|
||||
|
||||
if isinstance(response, (JSONResponse, StreamingResponse)):
|
||||
if response.background is None:
|
||||
response.background = BackgroundTask(decrement_server_load,
|
||||
raw_request)
|
||||
elif isinstance(response.background, BackgroundTasks):
|
||||
response.background.add_task(decrement_server_load,
|
||||
raw_request)
|
||||
elif isinstance(response.background, BackgroundTask):
|
||||
# Convert the single BackgroundTask to BackgroundTasks
|
||||
# and chain the decrement_server_load task to it
|
||||
tasks = BackgroundTasks()
|
||||
tasks.add_task(response.background.func,
|
||||
*response.background.args,
|
||||
**response.background.kwargs)
|
||||
tasks.add_task(decrement_server_load, raw_request)
|
||||
response.background = tasks
|
||||
else:
|
||||
raw_request.app.state.server_load_metrics -= 1
|
||||
|
||||
return response
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def cli_env_setup():
|
||||
# The safest multiprocessing method is `spawn`, as the default `fork` method
|
||||
# is not compatible with some accelerators. The default method will be
|
||||
# changing in future versions of Python, so we should use it explicitly when
|
||||
# possible.
|
||||
#
|
||||
# We only set it here in the CLI entrypoint, because changing to `spawn`
|
||||
# could break some existing code using vLLM as a library. `spawn` will cause
|
||||
# unexpected behavior if the code is not protected by
|
||||
# `if __name__ == "__main__":`.
|
||||
#
|
||||
# References:
|
||||
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
|
||||
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
|
||||
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
|
||||
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
|
||||
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
|
||||
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def _validate_truncation_size(
|
||||
max_model_len: int,
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[int]:
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens <= -1:
|
||||
truncate_prompt_tokens = max_model_len
|
||||
|
||||
if truncate_prompt_tokens > max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
|
||||
def show_filtered_argument_or_group_from_help(parser, subcommand_name):
|
||||
import sys
|
||||
|
||||
# Only handle --help=<keyword> for the current subcommand.
|
||||
# Since subparser_init() runs for all subcommands during CLI setup,
|
||||
# we skip processing if the subcommand name is not in sys.argv.
|
||||
if subcommand_name not in sys.argv:
|
||||
return
|
||||
|
||||
for arg in sys.argv:
|
||||
if arg.startswith('--help='):
|
||||
search_keyword = arg.split('=', 1)[1]
|
||||
|
||||
# List available groups
|
||||
if search_keyword == 'listgroup':
|
||||
print("\nAvailable argument groups:")
|
||||
for group in parser._action_groups:
|
||||
if group.title and not group.title.startswith(
|
||||
"positional arguments"):
|
||||
print(f" - {group.title}")
|
||||
if group.description:
|
||||
print(" " + group.description.strip())
|
||||
print()
|
||||
sys.exit(0)
|
||||
|
||||
# For group search
|
||||
formatter = parser._get_formatter()
|
||||
for group in parser._action_groups:
|
||||
if group.title and group.title.lower() == search_keyword.lower(
|
||||
):
|
||||
formatter.start_section(group.title)
|
||||
formatter.add_text(group.description)
|
||||
formatter.add_arguments(group._group_actions)
|
||||
formatter.end_section()
|
||||
print(formatter.format_help())
|
||||
sys.exit(0)
|
||||
|
||||
# For single arg
|
||||
matched_actions = []
|
||||
|
||||
for group in parser._action_groups:
|
||||
for action in group._group_actions:
|
||||
# search option name
|
||||
if any(search_keyword.lower() in opt.lower()
|
||||
for opt in action.option_strings):
|
||||
matched_actions.append(action)
|
||||
|
||||
if matched_actions:
|
||||
print(f"\nParameters matching '{search_keyword}':\n")
|
||||
formatter = parser._get_formatter()
|
||||
formatter.add_arguments(matched_actions)
|
||||
print(formatter.format_help())
|
||||
sys.exit(0)
|
||||
|
||||
print(f"\nNo group or parameter matching '{search_keyword}'")
|
||||
print("Tip: use `--help=listgroup` to view all groups.")
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user