init
This commit is contained in:
0
vllm/entrypoints/openai/__init__.py
Normal file
0
vllm/entrypoints/openai/__init__.py
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-312.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-312.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
1953
vllm/entrypoints/openai/api_server.py
Normal file
1953
vllm/entrypoints/openai/api_server.py
Normal file
File diff suppressed because it is too large
Load Diff
288
vllm/entrypoints/openai/cli_args.py
Normal file
288
vllm/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: list[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, ""]: # Skip if item is None or empty string
|
||||
continue
|
||||
if "=" in item and "," not in item: # Old format: name=path
|
||||
name, path = item.split("=")
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FrontendArgs:
|
||||
"""Arguments for the OpenAI-compatible frontend server."""
|
||||
host: Optional[str] = None
|
||||
"""Host name."""
|
||||
port: int = 8000
|
||||
"""Port number."""
|
||||
uds: Optional[str] = None
|
||||
"""Unix domain socket path. If set, host and port arguments are ignored."""
|
||||
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
|
||||
"trace"] = "info"
|
||||
"""Log level for uvicorn."""
|
||||
disable_uvicorn_access_log: bool = False
|
||||
"""Disable uvicorn access log."""
|
||||
allow_credentials: bool = False
|
||||
"""Allow credentials."""
|
||||
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed origins."""
|
||||
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed methods."""
|
||||
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed headers."""
|
||||
api_key: Optional[list[str]] = None
|
||||
"""If provided, the server will require one of these keys to be presented in
|
||||
the header."""
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None
|
||||
"""LoRA modules configurations in either 'name=path' format or JSON format
|
||||
or JSON list format. Example (old format): `'name=path'` Example (new
|
||||
format): `{\"name\": \"name\", \"path\": \"lora_path\",
|
||||
\"base_model_name\": \"id\"}`"""
|
||||
chat_template: Optional[str] = None
|
||||
"""The file path to the chat template, or the template in single-line form
|
||||
for the specified model."""
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: Optional[str] = None
|
||||
"""The file path to the SSL key file."""
|
||||
ssl_certfile: Optional[str] = None
|
||||
"""The file path to the SSL cert file."""
|
||||
ssl_ca_certs: Optional[str] = None
|
||||
"""The CA certificates file."""
|
||||
enable_ssl_refresh: bool = False
|
||||
"""Refresh SSL Context when SSL certificate files change"""
|
||||
ssl_cert_reqs: int = int(ssl.CERT_NONE)
|
||||
"""Whether client certificate is required (see stdlib ssl module's)."""
|
||||
root_path: Optional[str] = None
|
||||
"""FastAPI root_path when app is behind a path based routing proxy."""
|
||||
middleware: list[str] = field(default_factory=lambda: [])
|
||||
"""Additional ASGI middleware to apply to the app. We accept multiple
|
||||
--middleware arguments. The value should be an import path. If a function
|
||||
is provided, vLLM will add it to the server using
|
||||
`@app.middleware('http')`. If a class is provided, vLLM will
|
||||
add it to the server using `app.add_middleware()`."""
|
||||
return_tokens_as_token_ids: bool = False
|
||||
"""When `--max-logprobs` is specified, represents single tokens as
|
||||
strings of the form 'token_id:{token_id}' so that tokens that are not
|
||||
JSON-encodable can be identified."""
|
||||
disable_frontend_multiprocessing: bool = False
|
||||
"""If specified, will run the OpenAI frontend server in the same process as
|
||||
the model serving engine."""
|
||||
enable_request_id_headers: bool = False
|
||||
"""If specified, API server will add X-Request-Id header to responses."""
|
||||
enable_auto_tool_choice: bool = False
|
||||
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
|
||||
to specify which parser to use."""
|
||||
exclude_tools_when_tool_choice_none: bool = False
|
||||
"""If specified, exclude tool definitions in prompts when
|
||||
tool_choice='none'."""
|
||||
tool_call_parser: Optional[str] = None
|
||||
"""Select the tool call parser depending on the model that you're using.
|
||||
This is used to parse the model-generated tool call into OpenAI API format.
|
||||
Required for `--enable-auto-tool-choice`. You can choose any option from
|
||||
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
|
||||
tool_parser_plugin: str = ""
|
||||
"""Special the tool parser plugin write to parse the model-generated tool
|
||||
into OpenAI API format, the name register in this plugin can be used in
|
||||
`--tool-call-parser`."""
|
||||
tool_server: Optional[str] = None
|
||||
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
|
||||
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
|
||||
purpose."""
|
||||
log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH
|
||||
"""Path to logging config JSON file for both vllm and uvicorn"""
|
||||
max_log_len: Optional[int] = None
|
||||
"""Max number of prompt characters or prompt ID numbers being printed in
|
||||
log. The default of None means unlimited."""
|
||||
disable_fastapi_docs: bool = False
|
||||
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
|
||||
enable_prompt_tokens_details: bool = False
|
||||
"""If set to True, enable prompt_tokens_details in usage."""
|
||||
enable_server_load_tracking: bool = False
|
||||
"""If set to True, enable tracking server_load_metrics in the app state."""
|
||||
enable_force_include_usage: bool = False
|
||||
"""If set to True, including usage on every request."""
|
||||
enable_tokenizer_info_endpoint: bool = False
|
||||
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
||||
templates and other tokenizer configuration."""
|
||||
enable_log_outputs: bool = False
|
||||
"""If True, log model outputs (generations).
|
||||
Requires --enable-log-requests."""
|
||||
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
|
||||
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
|
||||
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
"""Maximum number of HTTP headers allowed in a request for h11 parser.
|
||||
Helps mitigate header abuse. Default: 256."""
|
||||
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
||||
"""If set to True, log the stack trace of error responses"""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
from vllm.engine.arg_utils import get_kwargs
|
||||
|
||||
frontend_kwargs = get_kwargs(FrontendArgs)
|
||||
|
||||
# Special case: allowed_origins, allowed_methods, allowed_headers all
|
||||
# need json.loads type
|
||||
# Should also remove nargs
|
||||
frontend_kwargs["allowed_origins"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_methods"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_headers"]["type"] = json.loads
|
||||
del frontend_kwargs["allowed_origins"]["nargs"]
|
||||
del frontend_kwargs["allowed_methods"]["nargs"]
|
||||
del frontend_kwargs["allowed_headers"]["nargs"]
|
||||
|
||||
# Special case: LoRA modules need custom parser action and
|
||||
# optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
|
||||
|
||||
# Special case: Middleware needs to append action
|
||||
frontend_kwargs["middleware"]["action"] = "append"
|
||||
frontend_kwargs["middleware"]["type"] = str
|
||||
if "nargs" in frontend_kwargs["middleware"]:
|
||||
del frontend_kwargs["middleware"]["nargs"]
|
||||
frontend_kwargs["middleware"]["default"] = []
|
||||
|
||||
# Special case: Tool call parser shows built-in options.
|
||||
valid_tool_parsers = list(ToolParserManager.tool_parsers.keys())
|
||||
parsers_str = ",".join(valid_tool_parsers)
|
||||
frontend_kwargs["tool_call_parser"]["metavar"] = (
|
||||
f"{{{parsers_str}}} or name registered in --tool-parser-plugin")
|
||||
|
||||
frontend_group = parser.add_argument_group(
|
||||
title="Frontend",
|
||||
description=FrontendArgs.__doc__,
|
||||
)
|
||||
|
||||
for key, value in frontend_kwargs.items():
|
||||
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Create the CLI argument parser used by the OpenAI API server.
|
||||
|
||||
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
|
||||
register all arguments instead of manually enumerating them here. This
|
||||
avoids code duplication and keeps the argument definitions in one place.
|
||||
"""
|
||||
parser.add_argument("model_tag",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="The model tag to serve "
|
||||
"(optional if specified in config)")
|
||||
parser.add_argument(
|
||||
"--headless",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.")
|
||||
parser.add_argument("--api-server-count",
|
||||
"-asc",
|
||||
type=int,
|
||||
default=1,
|
||||
help="How many API server processes to run.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="Read CLI options from a config file. "
|
||||
"Must be a YAML with the following options: "
|
||||
"https://docs.vllm.ai/en/latest/configuration/serve_args.html")
|
||||
parser = FrontendArgs.add_cli_args(parser)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
if args.enable_log_outputs and not args.enable_log_requests:
|
||||
raise TypeError("Error: --enable-log-outputs requires "
|
||||
"--enable-log-requests")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
return make_arg_parser(parser_for_docs)
|
||||
90
vllm/entrypoints/openai/logits_processors.py
Normal file
90
vllm/entrypoints/openai/logits_processors.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import lru_cache, partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
|
||||
"""Logits processor for constraining generated tokens to a
|
||||
specific set of token ids."""
|
||||
|
||||
def __init__(self, allowed_ids: Iterable[int]):
|
||||
self.allowed_ids: Optional[list[int]] = list(allowed_ids)
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, token_ids: list[int],
|
||||
logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.mask is None:
|
||||
self.mask = torch.ones((logits.shape[-1], ),
|
||||
dtype=torch.bool,
|
||||
device=logits.device)
|
||||
self.mask[self.allowed_ids] = False
|
||||
self.allowed_ids = None
|
||||
logits.masked_fill_(self.mask, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_allowed_token_ids_logits_processor(
|
||||
allowed_token_ids: frozenset[int],
|
||||
vocab_size: int,
|
||||
) -> LogitsProcessor:
|
||||
if not allowed_token_ids:
|
||||
raise ValueError("Empty allowed_token_ids provided")
|
||||
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains "
|
||||
"out-of-vocab token id")
|
||||
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(
|
||||
logit_bias: dict[int, float],
|
||||
token_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for token_id, bias in logit_bias.items():
|
||||
logits[token_id] += bias
|
||||
return logits
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]],
|
||||
allowed_token_ids: Optional[list[int]],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> list[LogitsProcessor]:
|
||||
logits_processors: list[LogitsProcessor] = []
|
||||
if logit_bias:
|
||||
try:
|
||||
# Convert token_id to integer
|
||||
# Clamp the bias between -100 and 100 per OpenAI API spec
|
||||
clamped_logit_bias: dict[int, float] = {
|
||||
int(token_id): min(100.0, max(-100.0, bias))
|
||||
for token_id, bias in logit_bias.items()
|
||||
}
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Found token_id in logit_bias that is not "
|
||||
"an integer or string representing an integer") from exc
|
||||
|
||||
# Check if token_id is within the vocab size
|
||||
for token_id, bias in clamped_logit_bias.items():
|
||||
if token_id < 0 or token_id >= len(tokenizer):
|
||||
raise ValueError(f"token_id {token_id} in logit_bias contains "
|
||||
"out-of-vocab token id")
|
||||
|
||||
logits_processors.append(
|
||||
partial(logit_bias_logits_processor, clamped_logit_bias))
|
||||
|
||||
if allowed_token_ids is not None:
|
||||
logits_processors.append(
|
||||
_get_allowed_token_ids_logits_processor(
|
||||
frozenset(allowed_token_ids), len(tokenizer)))
|
||||
|
||||
return logits_processors
|
||||
2757
vllm/entrypoints/openai/protocol.py
Normal file
2757
vllm/entrypoints/openai/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
491
vllm/entrypoints/openai/run_batch.py
Normal file
491
vllm/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
RerankResponse, ScoreResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help=
|
||||
"The path or url to a single input file. Currently supports local file "
|
||||
"paths, or the http protocol (http or https). If a URL is specified, "
|
||||
"the file should be available via HTTP GET.")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single output file. Currently supports "
|
||||
"local file paths, or web (http or https) urls. If a URL is specified,"
|
||||
" the file should be available via HTTP PUT.")
|
||||
parser.add_argument(
|
||||
"--output-tmp-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory to store the output file before uploading it "
|
||||
"to the output URL.",
|
||||
)
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=True`.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument("--enable-metrics",
|
||||
action="store_true",
|
||||
help="Enable Prometheus metrics")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="URL to the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible batch runner.")
|
||||
return make_arg_parser(parser).parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
|
||||
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: Optional[tqdm] = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
self._pbar = tqdm(total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, \
|
||||
session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_local_file(output_path: str,
|
||||
batch_outputs: list[BatchRequestOutput]) -> None:
|
||||
"""
|
||||
Write the responses to a local file.
|
||||
output_path: The path to write the responses to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
"""
|
||||
# We should make this async, but as long as run_batch runs as a
|
||||
# standalone program, blocking the event loop won't affect performance.
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=f)
|
||||
|
||||
|
||||
async def upload_data(output_url: str, data_or_file: str,
|
||||
from_file: bool) -> None:
|
||||
"""
|
||||
Upload a local file to a URL.
|
||||
output_url: The URL to upload the file to.
|
||||
data_or_file: Either the data to upload or the path to the file to upload.
|
||||
from_file: If True, data_or_file is the path to the file to upload.
|
||||
"""
|
||||
# Timeout is a common issue when uploading large files.
|
||||
# We retry max_retries times before giving up.
|
||||
max_retries = 5
|
||||
# Number of seconds to wait before retrying.
|
||||
delay = 5
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# We increase the timeout to 1000 seconds to allow
|
||||
# for large files (default is 300).
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
|
||||
total=1000)) as session:
|
||||
if from_file:
|
||||
with open(data_or_file, "rb") as file:
|
||||
async with session.put(output_url,
|
||||
data=file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload file.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
else:
|
||||
async with session.put(output_url,
|
||||
data=data_or_file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to upload data.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}")
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.error(
|
||||
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
|
||||
attempt,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
|
||||
) from e
|
||||
|
||||
|
||||
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
|
||||
output_tmp_dir: str) -> None:
|
||||
"""
|
||||
Write batch_outputs to a file or upload to a URL.
|
||||
path_or_url: The path or URL to write batch_outputs to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
output_tmp_dir: The directory to store the output file before uploading it
|
||||
to the output URL.
|
||||
"""
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
if output_tmp_dir is None:
|
||||
logger.info("Writing outputs to memory buffer")
|
||||
output_buffer = StringIO()
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=output_buffer)
|
||||
output_buffer.seek(0)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(
|
||||
path_or_url,
|
||||
output_buffer.read().strip().encode("utf-8"),
|
||||
from_file=False,
|
||||
)
|
||||
else:
|
||||
# Write responses to a temporary file and then upload it to the URL.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=output_tmp_dir,
|
||||
prefix="tmp_batch_output_",
|
||||
suffix=".jsonl",
|
||||
) as f:
|
||||
logger.info("Writing outputs to temporary local file %s",
|
||||
f.name)
|
||||
await write_local_file(f.name, batch_outputs)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(path_or_url, f.name, from_file=True)
|
||||
else:
|
||||
logger.info("Writing outputs to local file %s", path_or_url)
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
def make_error_request_output(request: BatchRequestInput,
|
||||
error_msg: str) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
|
||||
RerankResponse),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.error.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode")
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if "generate" in supported_tasks else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if "embed" in supported_tasks else None
|
||||
|
||||
enable_serving_reranking = ("classify" in supported_tasks and getattr(
|
||||
model_config.hf_config, "num_labels", 0) == 1)
|
||||
|
||||
openai_serving_scores = ServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if ("embed" in supported_tasks or enable_serving_reranking) else None
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: list[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
chat_handler_fn = openai_serving_chat.create_chat_completion if \
|
||||
openai_serving_chat is not None else None
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=
|
||||
"The model does not support Chat Completions API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
embed_handler_fn = openai_serving_embedding.create_embedding if \
|
||||
openai_serving_embedding is not None else None
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/score"):
|
||||
score_handler_fn = openai_serving_scores.create_score if \
|
||||
openai_serving_scores is not None else None
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/rerank"):
|
||||
rerank_handler_fn = openai_serving_scores.do_rerank if \
|
||||
openai_serving_scores is not None else None
|
||||
if rerank_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Rerank API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(
|
||||
run_request(rerank_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /score, /rerank ."
|
||||
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||
"score/rerank versions.",
|
||||
))
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
|
||||
|
||||
|
||||
async def main(args: Namespace):
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
|
||||
await run_batch(engine_client, vllm_config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
|
||||
1597
vllm/entrypoints/openai/serving_chat.py
Normal file
1597
vllm/entrypoints/openai/serving_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
173
vllm/entrypoints/openai/serving_classification.py
Normal file
173
vllm/entrypoints/openai/serving_classification.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ClassificationMixin(OpenAIServing):
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
if isinstance(ctx.request.input, str) and not ctx.request.input:
|
||||
return self.create_error_response(
|
||||
"Input cannot be empty for classification",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
renderer = self._get_renderer(ctx.tokenizer)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request))
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
probs = classify_res.probs
|
||||
predicted_index = int(np.argmax(probs))
|
||||
label = getattr(self.model_config.hf_config, "id2label",
|
||||
{}).get(predicted_index)
|
||||
|
||||
item = ClassificationData(
|
||||
index=idx,
|
||||
label=label,
|
||||
probs=probs,
|
||||
num_classes=len(probs),
|
||||
)
|
||||
|
||||
items.append(item)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ClassificationResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(self,
|
||||
request: ClassificationRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens)
|
||||
|
||||
|
||||
class ServingClassification(ClassificationMixin):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[ClassificationResponse, ErrorResponse]:
|
||||
model_name = self.models.model_name()
|
||||
request_id = (f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request)}")
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("classify", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
692
vllm/entrypoints/openai/serving_completion.py
Normal file
692
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,692 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import as_list, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info(
|
||||
"Using default completion sampling params from %s: %s",
|
||||
source,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"Echo is unsupported with prompt embeds.")
|
||||
|
||||
if (request.prompt_logprobs is not None
|
||||
and request.prompt_embeds is not None):
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds.")
|
||||
|
||||
request_id = (
|
||||
f"cmpl-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
# Mypy does not infer that engine_prompt will have only one of
|
||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
||||
# these as Union[object, the expected type], where it infers
|
||||
# object if engine_prompt is a subclass of one of the
|
||||
# typeddicts that defines both keys. Worse, because of
|
||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
||||
# infer the type of engine_prompt correctly because of the
|
||||
# enumerate. So we need an unnecessary cast here.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if is_embeds_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_embeds"])
|
||||
elif is_tokens_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_token_ids"])
|
||||
else:
|
||||
assert_never(engine_prompt)
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
enable_force_include_usage=self.enable_force_include_usage,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
enable_force_include_usage: bool,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
num_cached_tokens = None
|
||||
first_iteration = True
|
||||
|
||||
stream_options = request.stream_options
|
||||
if stream_options:
|
||||
include_usage = (stream_options.include_usage
|
||||
or enable_force_include_usage)
|
||||
include_continuous_usage = (include_usage and
|
||||
stream_options.continuous_usage_stats)
|
||||
else:
|
||||
include_usage, include_continuous_usage = False, False
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
|
||||
if first_iteration:
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
first_iteration = False
|
||||
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
|
||||
# Useful when request.return_token_ids is True
|
||||
# Returning prompt token IDs shares the same logic
|
||||
# with the echo implementation.
|
||||
prompt_token_ids_to_return: Optional[list[int]] = None
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids,
|
||||
*output.token_ids,
|
||||
]
|
||||
out_logprobs = [
|
||||
*(prompt_logprobs or []),
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
# has_echoed[i] is reused here to indicate whether
|
||||
# we have already returned the prompt token IDs.
|
||||
if not has_echoed[i]:
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
|
||||
if (not delta_text and not delta_token_ids
|
||||
and not previous_num_tokens[i]):
|
||||
# Chunked prefill case, don't return empty chunks
|
||||
continue
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
return_as_token_id=request.
|
||||
return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
prompt_token_ids=prompt_token_ids_to_return,
|
||||
token_ids=(as_list(output.token_ids) if
|
||||
request.return_token_ids else None),
|
||||
)
|
||||
],
|
||||
)
|
||||
if include_continuous_usage:
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=num_cached_tokens)
|
||||
|
||||
if include_usage:
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=final_usage_info,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = final_usage_info
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: list[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
kv_transfer_params = None
|
||||
last_final_res = None
|
||||
for final_res in final_res_batch:
|
||||
last_final_res = final_res
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
else:
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
prompt_token_ids=(prompt_token_ids
|
||||
if request.return_token_ids else None),
|
||||
token_ids=(as_list(output.token_ids)
|
||||
if request.return_token_ids else None),
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
if (self.enable_prompt_tokens_details and last_final_res
|
||||
and last_final_res.num_cached_tokens):
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=last_final_res.num_cached_tokens)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
if final_res_batch:
|
||||
kv_transfer_params = final_res_batch[0].kv_transfer_params
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: Optional[bool] = None,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: list[int] = []
|
||||
out_token_logprobs: list[Optional[float]] = []
|
||||
out_tokens: list[str] = []
|
||||
out_top_logprobs: list[Optional[dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
should_return_as_token_id = (return_as_token_id
|
||||
if return_as_token_id is not None else
|
||||
self.return_tokens_as_token_ids)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
max_input_length: Optional[int] = None,
|
||||
) -> RenderConfig:
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
|
||||
return RenderConfig(
|
||||
max_length=max_input_tokens_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo
|
||||
and not request.return_token_ids),
|
||||
)
|
||||
631
vllm/entrypoints/openai/serving_embedding.py
Normal file
631
vllm/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Final, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext,
|
||||
TextTokensPrompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingOutput, PoolingRequestOutput, RequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import chunk_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.embedding
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
|
||||
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class EmbeddingMixin(OpenAIServing):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
pooler_config = self.model_config.pooler_config
|
||||
|
||||
# Avoid repeated attribute lookups
|
||||
self.supports_chunked_processing = bool(
|
||||
pooler_config and pooler_config.enable_chunked_processing)
|
||||
self.max_embed_len = (pooler_config.max_embed_len if pooler_config
|
||||
and pooler_config.max_embed_len else None)
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template
|
||||
or ctx.chat_template,
|
||||
chat_template_content_format=ctx.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=ctx.request.add_generation_prompt,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_render_config(
|
||||
self, request: EmbeddingCompletionRequest) -> RenderConfig:
|
||||
# Set max_length based on chunked processing capability
|
||||
if self._should_use_chunked_processing(request):
|
||||
max_length = None
|
||||
else:
|
||||
max_length = self.max_embed_len or self.max_model_len
|
||||
|
||||
return RenderConfig(
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens)
|
||||
|
||||
@override
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
items: list[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
embedding_res = EmbeddingRequestOutput.from_base(final_res)
|
||||
|
||||
item = EmbeddingResponseData(
|
||||
index=idx,
|
||||
embedding=_get_embedding(embedding_res.outputs,
|
||||
ctx.request.encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _get_max_position_embeddings(self) -> int:
|
||||
"""Get the model's effective maximum sequence length for chunking."""
|
||||
return self.model_config.max_model_len
|
||||
|
||||
def _should_use_chunked_processing(self, request) -> bool:
|
||||
"""Check if chunked processing should be used for this request."""
|
||||
return isinstance(
|
||||
request,
|
||||
(EmbeddingCompletionRequest,
|
||||
EmbeddingChatRequest)) and self.supports_chunked_processing
|
||||
|
||||
async def _process_chunked_request(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
original_prompt: TextTokensPrompt,
|
||||
pooling_params,
|
||||
trace_headers,
|
||||
prompt_idx: int,
|
||||
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
|
||||
"""Process a single prompt using chunked processing."""
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
token_ids = original_prompt["prompt_token_ids"]
|
||||
|
||||
# Split into chunks using max_position_embeddings
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
# Process all chunks for MEAN aggregation
|
||||
for chunk_idx, chunk_tokens in enumerate(
|
||||
chunk_list(token_ids, max_pos_embeddings)):
|
||||
# Create a request ID for this chunk
|
||||
chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
|
||||
f"chunk-{chunk_idx}")
|
||||
|
||||
# Create engine prompt for this chunk
|
||||
chunk_engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Create chunk request prompt for logging
|
||||
chunk_text = ""
|
||||
chunk_request_prompt = TextTokensPrompt(
|
||||
prompt=chunk_text, prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Log the chunk
|
||||
self._log_inputs(chunk_request_id,
|
||||
chunk_request_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(original_generator)
|
||||
|
||||
return generators
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
"""Override to support chunked processing for embedding requests."""
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingCompletionRequest, EmbeddingChatRequest)):
|
||||
# Check if chunked processing is enabled for pooling models
|
||||
enable_chunked = self._should_use_chunked_processing(request)
|
||||
|
||||
# Use max_position_embeddings for chunked processing decisions
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
# Determine the effective max length for validation
|
||||
if self.max_embed_len is not None:
|
||||
# Use max_embed_len for validation instead of max_model_len
|
||||
length_type = "maximum embedding input length"
|
||||
max_length_value = self.max_embed_len
|
||||
else:
|
||||
# Fall back to max_model_len validation (original behavior)
|
||||
length_type = "maximum context length"
|
||||
max_length_value = self.max_model_len
|
||||
|
||||
validation_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input.")
|
||||
|
||||
chunked_processing_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input "
|
||||
"or enable chunked processing.")
|
||||
|
||||
# Check if input exceeds max length
|
||||
if token_num > max_length_value:
|
||||
raise ValueError(
|
||||
validation_error_msg.format(
|
||||
length_type=length_type,
|
||||
max_length_value=max_length_value,
|
||||
token_num=token_num))
|
||||
|
||||
# Check for chunked processing
|
||||
# when exceeding max_position_embeddings
|
||||
if token_num > max_pos_embeddings:
|
||||
if enable_chunked:
|
||||
# Allow long inputs when chunked processing is enabled
|
||||
logger.info(
|
||||
"Input length %s exceeds max_position_embeddings "
|
||||
"%s, will use chunked processing", token_num,
|
||||
max_pos_embeddings)
|
||||
else:
|
||||
raise ValueError(
|
||||
chunked_processing_error_msg.format(
|
||||
length_type="maximum position embeddings length",
|
||||
max_length_value=max_pos_embeddings,
|
||||
token_num=token_num))
|
||||
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# For other request types, use the parent's implementation
|
||||
return super()._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _is_text_tokens_prompt(self, prompt) -> bool:
|
||||
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Optional[Mapping[str, str]],
|
||||
prompt_index: int,
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
"""Create a generator for a single prompt using standard processing."""
|
||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Override to support chunked processing."""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
|
||||
# Check if we should use chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
# If no chunked processing needed, delegate to parent class
|
||||
if not use_chunked:
|
||||
return await super()._prepare_generators(ctx)
|
||||
|
||||
# Custom logic for chunked processing
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
# Verify and set the task for pooling params
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if self._is_text_tokens_prompt(engine_prompt):
|
||||
# Cast to TextTokensPrompt since we've verified
|
||||
# prompt_token_ids
|
||||
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
|
||||
if (len(text_tokens_prompt["prompt_token_ids"])
|
||||
> max_pos_embeddings):
|
||||
# Use chunked processing for this prompt
|
||||
chunk_generators = await self._process_chunked_request(
|
||||
ctx, text_tokens_prompt, pooling_params,
|
||||
trace_headers, i)
|
||||
generators.extend(chunk_generators)
|
||||
continue
|
||||
|
||||
# Normal processing for short prompts or non-token prompts
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt, pooling_params, trace_headers, i)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect and aggregate batch results
|
||||
with support for chunked processing.
|
||||
|
||||
For chunked requests, performs online aggregation to
|
||||
minimize memory usage.
|
||||
For regular requests, collects results normally.
|
||||
"""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
# Check if we used chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
if not use_chunked:
|
||||
return await super()._collect_batch(ctx=ctx)
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
prompt_aggregators: dict[int, dict[str, Any]] = {}
|
||||
short_prompts_results: dict[int, PoolingRequestOutput] = {}
|
||||
|
||||
async for result_idx, result in ctx.result_generator:
|
||||
if "-chunk-" in result.request_id:
|
||||
# Extract prompt_idx from chunked request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
prompt_idx = int(parts[parts.index("prompt") + 1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback: extract from result_idx if parsing fails
|
||||
prompt_idx = result_idx
|
||||
|
||||
# Initialize aggregator for this prompt if needed
|
||||
if prompt_idx not in prompt_aggregators:
|
||||
prompt_aggregators[prompt_idx] = {
|
||||
'weighted_sum': None,
|
||||
'total_weight': 0,
|
||||
'chunk_count': 0,
|
||||
'request_id': result.request_id.split("-chunk-")[0]
|
||||
}
|
||||
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
# MEAN pooling with online weighted averaging
|
||||
# Ensure result is PoolingRequestOutput
|
||||
# for embedding processing
|
||||
if not isinstance(result, PoolingRequestOutput):
|
||||
return self.create_error_response(
|
||||
f"Expected PoolingRequestOutput for "
|
||||
f"chunked embedding, got "
|
||||
f"{type(result).__name__}")
|
||||
|
||||
# Handle both PoolingOutput and
|
||||
# EmbeddingOutput types
|
||||
if hasattr(result.outputs, 'data'):
|
||||
# PoolingOutput case
|
||||
embedding_data = result.outputs.data
|
||||
elif hasattr(result.outputs, 'embedding'):
|
||||
# EmbeddingOutput case -
|
||||
# convert embedding list to tensor
|
||||
embedding_data = result.outputs.embedding
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Unsupported output type: "
|
||||
f"{type(result.outputs).__name__}")
|
||||
|
||||
if not isinstance(embedding_data, torch.Tensor):
|
||||
embedding_data = torch.tensor(embedding_data,
|
||||
dtype=torch.float32)
|
||||
|
||||
if result.prompt_token_ids is None:
|
||||
return self.create_error_response(
|
||||
"prompt_token_ids cannot be None for "
|
||||
"chunked processing")
|
||||
weight = len(result.prompt_token_ids)
|
||||
|
||||
weighted_embedding = embedding_data.to(
|
||||
dtype=torch.float32) * weight
|
||||
|
||||
if aggregator['weighted_sum'] is None:
|
||||
# First chunk
|
||||
aggregator['weighted_sum'] = weighted_embedding
|
||||
else:
|
||||
# Accumulate
|
||||
aggregator['weighted_sum'] += weighted_embedding
|
||||
|
||||
aggregator['total_weight'] += weight
|
||||
aggregator['chunk_count'] += 1
|
||||
else:
|
||||
# Non-chunked result - extract prompt_idx from request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
# Last part should be prompt index
|
||||
prompt_idx = int(parts[-1])
|
||||
except (ValueError, IndexError):
|
||||
prompt_idx = result_idx # Fallback to result_idx
|
||||
|
||||
short_prompts_results[prompt_idx] = cast(
|
||||
PoolingRequestOutput, result)
|
||||
|
||||
# Finalize aggregated results
|
||||
final_res_batch: list[Union[PoolingRequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
if prompt_idx in prompt_aggregators:
|
||||
# Finalize MEAN aggregation for this chunked prompt
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
weighted_sum = aggregator['weighted_sum']
|
||||
total_weight = aggregator['total_weight']
|
||||
|
||||
if (weighted_sum is not None
|
||||
and isinstance(weighted_sum, torch.Tensor)
|
||||
and isinstance(total_weight,
|
||||
(int, float)) and total_weight > 0):
|
||||
|
||||
# Compute final mean embedding
|
||||
final_embedding = weighted_sum / total_weight
|
||||
|
||||
# Create a PoolingRequestOutput
|
||||
# for the aggregated result
|
||||
pooling_output_data = PoolingOutput(
|
||||
data=final_embedding)
|
||||
|
||||
# Get original prompt token IDs for this prompt
|
||||
original_prompt = ctx.engine_prompts[prompt_idx]
|
||||
if not self._is_text_tokens_prompt(original_prompt):
|
||||
return self.create_error_response(
|
||||
f"Chunked prompt {prompt_idx} is not a "
|
||||
f"TextTokensPrompt")
|
||||
|
||||
original_token_ids = cast(
|
||||
TextTokensPrompt,
|
||||
original_prompt)["prompt_token_ids"]
|
||||
|
||||
pooling_request_output = PoolingRequestOutput(
|
||||
request_id=aggregator['request_id'],
|
||||
prompt_token_ids=original_token_ids,
|
||||
outputs=pooling_output_data,
|
||||
finished=True)
|
||||
|
||||
final_res_batch.append(pooling_request_output)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Failed to aggregate chunks "
|
||||
f"for prompt {prompt_idx}")
|
||||
elif prompt_idx in short_prompts_results:
|
||||
final_res_batch.append(
|
||||
cast(PoolingRequestOutput,
|
||||
short_prompts_results[prompt_idx]))
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Result not found for prompt {prompt_idx}")
|
||||
|
||||
ctx.final_res_batch = cast(
|
||||
list[Union[RequestOutput, PoolingRequestOutput]],
|
||||
final_res_batch)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
request_id_prefix = "embd"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
model_name = self.models.model_name()
|
||||
request_id = (
|
||||
f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
|
||||
ctx = EmbeddingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ServeContext[EmbeddingRequest],
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
992
vllm/entrypoints/openai/serving_engine.py
Normal file
992
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.entrypoints.context import ConversationContext
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
PoolingResponse, RerankRequest,
|
||||
ResponsesRequest, ScoreRequest,
|
||||
ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
|
||||
RenderConfig)
|
||||
# yapf: enable
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
|
||||
MultiModalDataDict, MultiModalUUIDDict)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
|
||||
merge_async_iterators, random_uuid)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CompletionLikeRequest = Union[
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
RerankRequest,
|
||||
ClassificationRequest,
|
||||
ScoreRequest,
|
||||
TokenizeCompletionRequest,
|
||||
]
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||
AnyRequest = Union[
|
||||
CompletionLikeRequest,
|
||||
ChatLikeRequest,
|
||||
SpeechToTextRequest,
|
||||
ResponsesRequest,
|
||||
IOProcessorRequest,
|
||||
]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
TranscriptionResponse,
|
||||
TokenizeResponse,
|
||||
PoolingResponse,
|
||||
ClassificationResponse,
|
||||
ScoreResponse,
|
||||
]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
prompt: str
|
||||
prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class EmbedsPrompt(TypedDict):
|
||||
prompt_embeds: torch.Tensor
|
||||
|
||||
|
||||
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
|
||||
|
||||
|
||||
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
|
||||
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
|
||||
and "prompt_embeds" in prompt)
|
||||
|
||||
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
class RequestProcessingMixin(BaseModel):
|
||||
"""
|
||||
Mixin for request processing,
|
||||
handling prompt preparation and engine input.
|
||||
"""
|
||||
|
||||
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
||||
engine_prompts: Optional[list[EngineTokensPrompt]] = []
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ResponseGenerationMixin(BaseModel):
|
||||
"""
|
||||
Mixin for response generation,
|
||||
managing result generators and final batch results.
|
||||
"""
|
||||
|
||||
result_generator: Optional[AsyncGenerator[tuple[int, Union[
|
||||
RequestOutput, PoolingRequestOutput]], None]] = None
|
||||
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
|
||||
default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ServeContext(
|
||||
RequestProcessingMixin,
|
||||
ResponseGenerationMixin,
|
||||
BaseModel,
|
||||
Generic[RequestT],
|
||||
):
|
||||
# Shared across all requests
|
||||
request: RequestT
|
||||
raw_request: Optional[Request] = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = Field(default_factory=lambda: int(time.time()))
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: Optional[AnyTokenizer] = None
|
||||
|
||||
# `protected_namespaces` resolves Pydantic v2's warning
|
||||
# on conflict with protected namespace "model_"
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
ClassificationServeContext = ServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
|
||||
chat_template: Optional[str] = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption
|
||||
|
||||
|
||||
# Used to resolve the Pydantic error related to
|
||||
# forward reference of MultiModalDataDict in TokensPrompt
|
||||
RequestProcessingMixin.model_rebuild()
|
||||
ServeContext.model_rebuild()
|
||||
ClassificationServeContext.model_rebuild()
|
||||
EmbeddingServeContext.model_rebuild()
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
so you can easily tell “this ID came from Embedding vs Classification.”
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer,
|
||||
AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
||||
"""
|
||||
Get a Renderer instance with the provided tokenizer.
|
||||
Uses shared async tokenizer pool for efficiency.
|
||||
"""
|
||||
return CompletionRenderer(
|
||||
model_config=self.model_config,
|
||||
tokenizer=tokenizer,
|
||||
async_tokenizer_pool=self._async_tokenizer_pool)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: Any,
|
||||
) -> RenderConfig:
|
||||
"""
|
||||
Build and return a `RenderConfig` for an endpoint.
|
||||
|
||||
Used by the renderer to control how prompts are prepared
|
||||
(e.g., tokenization and length handling). Endpoints should
|
||||
implement this with logic appropriate to their request type.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
|
||||
"""
|
||||
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
|
||||
given tokenizer.
|
||||
"""
|
||||
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
|
||||
if async_tokenizer is None:
|
||||
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
|
||||
self._async_tokenizer_pool[tokenizer] = async_tokenizer
|
||||
return async_tokenizer
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""
|
||||
Default preprocessing hook. Subclasses may override
|
||||
to prepare `ctx` (classification, embedding, etc.).
|
||||
"""
|
||||
return None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
"""
|
||||
Default response builder. Subclass may override this method
|
||||
to return the appropriate response object.
|
||||
"""
|
||||
return self.create_error_response("unimplemented endpoint")
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[AnyResponse, ErrorResponse]:
|
||||
generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
|
||||
generation = self._pipeline(ctx)
|
||||
|
||||
async for response in generation:
|
||||
return response
|
||||
|
||||
return self.create_error_response("No response yielded from pipeline")
|
||||
|
||||
async def _pipeline(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
|
||||
"""Execute the request processing pipeline yielding responses."""
|
||||
if error := await self._check_model(ctx.request):
|
||||
yield error
|
||||
if error := self._validate_request(ctx):
|
||||
yield error
|
||||
|
||||
preprocess_ret = await self._preprocess(ctx)
|
||||
if isinstance(preprocess_ret, ErrorResponse):
|
||||
yield preprocess_ret
|
||||
|
||||
generators_ret = await self._prepare_generators(ctx)
|
||||
if isinstance(generators_ret, ErrorResponse):
|
||||
yield generators_ret
|
||||
|
||||
collect_ret = await self._collect_batch(ctx)
|
||||
if isinstance(collect_ret, ErrorResponse):
|
||||
yield collect_ret
|
||||
|
||||
yield self._build_response(ctx)
|
||||
|
||||
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
|
||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if (truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > self.max_model_len):
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
return None
|
||||
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Union[PoolingParams, ErrorResponse]:
|
||||
if not hasattr(ctx.request, "to_pooling_params"):
|
||||
return self.create_error_response(
|
||||
"Request type does not support pooling parameters")
|
||||
|
||||
return ctx.request.to_pooling_params()
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Schedule the request and get the result generator."""
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect batch results from the result generator."""
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
final_res_batch: list[Optional[Union[RequestOutput,
|
||||
PoolingRequestOutput]]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
async for i, res in ctx.result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
if None in final_res_batch:
|
||||
return self.create_error_response(
|
||||
"Failed to generate results for all prompts")
|
||||
|
||||
ctx.final_res_batch = [
|
||||
res for res in final_res_batch if res is not None
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> ErrorResponse:
|
||||
if self.log_error_stack:
|
||||
exc_type, _, _ = sys.exc_info()
|
||||
if exc_type is not None:
|
||||
traceback.print_exc()
|
||||
else:
|
||||
traceback.print_stack()
|
||||
return ErrorResponse(error=ErrorInfo(
|
||||
message=message, type=err_type, code=status_code.value))
|
||||
|
||||
def create_streaming_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> str:
|
||||
json_str = json.dumps(
|
||||
self.create_error_response(message=message,
|
||||
err_type=err_type,
|
||||
status_code=status_code).model_dump())
|
||||
return json_str
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
error_response = None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in self.models.lora_requests:
|
||||
return None
|
||||
if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
|
||||
(load_result := await self.models.resolve_lora(request.model))):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if (isinstance(load_result, ErrorResponse) and
|
||||
load_result.error.code == HTTPStatus.BAD_REQUEST.value):
|
||||
error_response = load_result
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
def _get_active_default_mm_loras(
|
||||
self, request: AnyRequest) -> Optional[LoRARequest]:
|
||||
"""Determine if there are any active default multimodal loras."""
|
||||
# TODO: Currently this is only enabled for chat completions
|
||||
# to be better aligned with only being enabled for .generate
|
||||
# when run offline. It would be nice to support additional
|
||||
# tasks types in the future.
|
||||
message_types = self._get_message_types(request)
|
||||
default_mm_loras = set()
|
||||
|
||||
for lora in self.models.lora_requests.values():
|
||||
# Best effort match for default multimodal lora adapters;
|
||||
# There is probably a better way to do this, but currently
|
||||
# this matches against the set of 'types' in any content lists
|
||||
# up until '_', e.g., to match audio_url -> audio
|
||||
if lora.lora_name in message_types:
|
||||
default_mm_loras.add(lora)
|
||||
|
||||
# Currently only support default modality specific loras if
|
||||
# we have exactly one lora matched on the request.
|
||||
if len(default_mm_loras) == 1:
|
||||
return default_mm_loras.pop()
|
||||
return None
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
supports_default_mm_loras: bool = False,
|
||||
) -> Optional[LoRARequest]:
|
||||
if request.model in self.models.lora_requests:
|
||||
return self.models.lora_requests[request.model]
|
||||
|
||||
# Currently only support default modality specific loras
|
||||
# if we have exactly one lora matched on the request.
|
||||
if supports_default_mm_loras:
|
||||
default_mm_lora = self._get_active_default_mm_loras(request)
|
||||
if default_mm_lora is not None:
|
||||
return default_mm_lora
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _get_message_types(self, request: AnyRequest) -> set[str]:
|
||||
"""Retrieve the set of types from message content dicts up
|
||||
until `_`; we use this to match potential multimodal data
|
||||
with default per modality loras.
|
||||
"""
|
||||
message_types: set[str] = set()
|
||||
|
||||
if not hasattr(request, "messages"):
|
||||
return message_types
|
||||
|
||||
for message in request.messages:
|
||||
if (isinstance(message, dict) and "content" in message
|
||||
and isinstance(message["content"], list)):
|
||||
for content_dict in message["content"]:
|
||||
if "type" in content_dict:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
async def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
|
||||
if (self.model_config.encoder_config is not None
|
||||
and self.model_config.encoder_config.get(
|
||||
"do_lower_case", False)):
|
||||
prompt = prompt.lower()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = await async_tokenizer(
|
||||
prompt, add_special_tokens=add_special_tokens)
|
||||
elif truncate_prompt_tokens < 0:
|
||||
# Negative means we cap at the model's max length
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=self.max_model_len,
|
||||
)
|
||||
else:
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
input_ids = encoded.input_ids
|
||||
input_text = prompt
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
async def _normalize_prompt_tokens_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt_ids: list[int],
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
) -> TextTokensPrompt:
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
input_ids = prompt_ids
|
||||
elif truncate_prompt_tokens < 0:
|
||||
input_ids = prompt_ids[-self.max_model_len:]
|
||||
else:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
|
||||
if tokenizer is None:
|
||||
input_text = ""
|
||||
else:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
input_text = await async_tokenizer.decode(input_ids)
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
(
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ScoreRequest,
|
||||
RerankRequest,
|
||||
ClassificationRequest,
|
||||
),
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
# since these requests don't generate tokens.
|
||||
if token_num > self.max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreRequest: "score",
|
||||
ClassificationRequest: "classification",
|
||||
}
|
||||
operation = operations.get(type(request),
|
||||
"embedding generation")
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input.")
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(
|
||||
request,
|
||||
(TokenizeCompletionRequest, TokenizeChatRequest,
|
||||
DetokenizeRequest),
|
||||
):
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# chat completion endpoint supports max_completion_tokens
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages.")
|
||||
|
||||
if (max_tokens is not None
|
||||
and token_num + max_tokens > self.max_model_len):
|
||||
raise ValueError(
|
||||
"'max_tokens' or 'max_completion_tokens' is too large: "
|
||||
f"{max_tokens}. This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
|
||||
f" - {token_num}).")
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
async def _tokenize_prompt_input_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_input: Union[str, list[int]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
"""
|
||||
A simpler implementation that tokenizes a single prompt input.
|
||||
"""
|
||||
async for result in self._tokenize_prompt_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
[prompt_input],
|
||||
add_special_tokens=add_special_tokens,
|
||||
):
|
||||
return result
|
||||
raise ValueError("No results yielded from tokenization")
|
||||
|
||||
async def _tokenize_prompt_inputs_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_inputs: Iterable[Union[str, list[int]]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> AsyncGenerator[TextTokensPrompt, None]:
|
||||
"""
|
||||
A simpler implementation that tokenizes multiple prompt inputs.
|
||||
"""
|
||||
for prompt in prompt_inputs:
|
||||
if isinstance(prompt, str):
|
||||
yield await self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
prompt=prompt,
|
||||
tokenizer=tokenizer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield await self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
prompt_ids=prompt,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: Union[ChatLikeRequest, ResponsesRequest],
|
||||
tokenizer: AnyTokenizer,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tool_dicts: Optional[list[dict[str, Any]]] = None,
|
||||
documents: Optional[list[dict[str, str]]] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Sequence[RequestPrompt],
|
||||
list[EngineTokensPrompt],
|
||||
]:
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tool_dicts,
|
||||
chat_template_content_format,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
|
||||
messages,
|
||||
model_config,
|
||||
tokenizer,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
_chat_template_kwargs: dict[str, Any] = dict(
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
)
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
|
||||
request_prompt: Union[str, list[int]]
|
||||
|
||||
if tokenizer is None:
|
||||
request_prompt = "placeholder"
|
||||
elif isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
conversation=conversation,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
|
||||
should_parse_tools = tool_parser is not None and (hasattr(
|
||||
request, "tool_choice") and request.tool_choice != "none")
|
||||
|
||||
if should_parse_tools:
|
||||
if not isinstance(request, ChatCompletionRequest):
|
||||
msg = "Tool usage is only supported for Chat Completions API"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
request = tool_parser(tokenizer).adjust_request( # type: ignore
|
||||
request=request)
|
||||
|
||||
if tokenizer is None:
|
||||
assert isinstance(request_prompt, str), (
|
||||
"Prompt has to be a string",
|
||||
"when the tokenizer is not initialised",
|
||||
)
|
||||
prompt_inputs = TextTokensPrompt(prompt=request_prompt,
|
||||
prompt_token_ids=[1])
|
||||
elif isinstance(request_prompt, str):
|
||||
prompt_inputs = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For MistralTokenizer
|
||||
assert is_list_of(request_prompt, int), (
|
||||
"Prompt has to be either a string or a list of token ids")
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(request_prompt),
|
||||
prompt_token_ids=request_prompt,
|
||||
)
|
||||
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
|
||||
return conversation, [request_prompt], [engine_prompt]
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
request_prompt: RequestPrompt,
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
sampling_params: SamplingParams,
|
||||
context: ConversationContext,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
priority: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
orig_priority = priority
|
||||
while True:
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
request_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
async for res in generator:
|
||||
context.append_output(res)
|
||||
# NOTE(woosuk): The stop condition is handled by the engine.
|
||||
yield context
|
||||
|
||||
if not context.need_builtin_tool_call():
|
||||
# The model did not ask for a tool call, so we're done.
|
||||
break
|
||||
|
||||
# Call the tool and update the context with the result.
|
||||
tool_output = await context.call_tool()
|
||||
context.append_output(tool_output)
|
||||
|
||||
# TODO: uncomment this and enable tool output streaming
|
||||
# yield context
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids.
|
||||
prompt_token_ids = context.render_for_completion()
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=prompt_token_ids)
|
||||
request_prompt = prompt_token_ids
|
||||
# Update the sampling params.
|
||||
sampling_params.max_tokens = self.max_model_len - len(
|
||||
prompt_token_ids)
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: Union[RequestPrompt, PromptType],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
prompt, prompt_token_ids, prompt_embeds = None, None, None
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
elif isinstance(inputs, list):
|
||||
prompt_token_ids = inputs
|
||||
else:
|
||||
prompt = getattr(inputs, 'prompt', None)
|
||||
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Optional[Mapping[str, str]]:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(raw_request: Optional[Request],
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
default = default or random_uuid()
|
||||
if raw_request is None:
|
||||
return default
|
||||
|
||||
return raw_request.headers.get("X-Request-Id", default)
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(
|
||||
logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
return_as_token_id: bool = False,
|
||||
) -> str:
|
||||
if return_as_token_id:
|
||||
return f"token_id:{token_id}"
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name: Optional[str]) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
|
||||
def clamp_prompt_logprobs(
|
||||
prompt_logprobs: Union[PromptLogprobs,
|
||||
None], ) -> Union[PromptLogprobs, None]:
|
||||
if prompt_logprobs is None:
|
||||
return prompt_logprobs
|
||||
|
||||
for logprob_dict in prompt_logprobs:
|
||||
if logprob_dict is None:
|
||||
continue
|
||||
for logprob_values in logprob_dict.values():
|
||||
if logprob_values.logprob == float("-inf"):
|
||||
logprob_values.logprob = -9999.0
|
||||
return prompt_logprobs
|
||||
288
vllm/entrypoints/openai/serving_models.py
Normal file
288
vllm/entrypoints/openai/serving_models.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
|
||||
):
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
if self.static_lora_modules is None:
|
||||
return
|
||||
for lora in self.static_lora_modules:
|
||||
load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
|
||||
lora_name=lora.name)
|
||||
load_result = await self.load_lora_adapter(
|
||||
request=load_request, base_model_name=lora.base_model_name)
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.error.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoRAAdapterRequest,
|
||||
base_model_name: Optional[str] = None
|
||||
) -> Union[ErrorResponse, str]:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(
|
||||
request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_path = request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
lora_request = LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path)
|
||||
if base_model_name is not None and self.is_base_model(
|
||||
base_model_name):
|
||||
lora_request.base_model_name = base_model_name
|
||||
|
||||
# Validate that the adapter can be loaded into the engine
|
||||
# This will also pre-load it for incoming requests
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
except Exception as e:
|
||||
error_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "No adapter found" in str(e):
|
||||
error_type = "NotFoundError"
|
||||
status_code = HTTPStatus.NOT_FOUND
|
||||
|
||||
return create_error_response(message=str(e),
|
||||
err_type=error_type,
|
||||
status_code=status_code)
|
||||
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
|
||||
lora_name, lora_path)
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(
|
||||
request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# Safe to delete now since we hold the lock
|
||||
del self.lora_requests[lora_name]
|
||||
logger.info("Removed LoRA adapter: name '%s'", lora_name)
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if request.lora_name in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been "
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if 'lora_name' is not provided return an error
|
||||
if not request.lora_name:
|
||||
return create_error_response(
|
||||
message=
|
||||
"'lora_name' needs to be provided to unload a LoRA adapter.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if request.lora_name not in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(
|
||||
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
if lora_name in self.lora_requests:
|
||||
return self.lora_requests[lora_name]
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(
|
||||
base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name, resolver.__class__.__name__)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.", lora_name,
|
||||
resolver.__class__.__name__, e)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(f"LoRA adapter '{lora_name}' was found "
|
||||
"but could not be loaded."),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(error=ErrorInfo(
|
||||
message=message, type=err_type, code=status_code.value))
|
||||
276
vllm/entrypoints/openai/serving_pooling.py
Normal file
276
vllm/entrypoints/openai/serving_pooling.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_data(
|
||||
output: PoolingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[list[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.data.tolist()
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
pt_float32 = output.data.to(dtype=torch.float32)
|
||||
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
|
||||
return base64.b64encode(pooling_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = self.models.model_name()
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id)
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
# In pooling requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.input,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify("encode", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request,
|
||||
(PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_pooling_response(
|
||||
final_res_batch_checked,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
request.encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_pooling_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> PoolingResponse:
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=_get_data(final_res.outputs, encoding_format),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return PoolingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self, request: PoolingCompletionRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens)
|
||||
1709
vllm/entrypoints/openai/serving_responses.py
Normal file
1709
vllm/entrypoints/openai/serving_responses.py
Normal file
File diff suppressed because it is too large
Load Diff
479
vllm/entrypoints/openai/serving_score.py
Normal file
479
vllm/entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ScoreResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
_validate_score_input_lens,
|
||||
compress_token_type_ids,
|
||||
get_score_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
input_texts = texts_1 + texts_2
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
|
||||
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
|
||||
text_token_prompt = \
|
||||
self._validate_input(
|
||||
request,
|
||||
tok_result["input_ids"],
|
||||
input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[Optional[PoolingRequestOutput]] =\
|
||||
[None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_texts_1: list[PoolingRequestOutput] = []
|
||||
emb_texts_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(texts_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_1.append(emb)
|
||||
|
||||
for i in range(len(texts_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_2.append(emb)
|
||||
|
||||
if len(emb_texts_1) == 1:
|
||||
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
|
||||
embed_1=emb_texts_1,
|
||||
embed_2=emb_texts_2)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: Union[str, ScoreContentPartParam],
|
||||
data_2: Union[str, ScoreContentPartParam],
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=data_1,
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"],
|
||||
full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
data_1: Union[list[str], list[ScoreContentPartParam]],
|
||||
data_2: Union[list[str], list[ScoreContentPartParam]],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
preprocess_async = make_async(self._preprocess_score,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
*(preprocess_async(request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2) for t1, t2 in input_pairs))
|
||||
|
||||
for full_prompt, engine_prompt in preprocessed_prompts:
|
||||
request_prompts.append(full_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
default_pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
default_pooling_params.verify("score", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=default_pooling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
|
||||
pooling_params = default_pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
pooling_params.extra_kwargs = {
|
||||
"compressed_token_type_ids": compressed
|
||||
}
|
||||
else:
|
||||
pooling_params = (default_pooling_params)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[
|
||||
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
data_1: Union[list[str], str, ScoreMultiModalParam],
|
||||
data_2: Union[list[str], str, ScoreMultiModalParam],
|
||||
request: Union[ScoreRequest, RerankRequest],
|
||||
request_id: str,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if not self.model_config.is_multimodal_model and (isinstance(
|
||||
data_1, dict) or isinstance(data_2, dict)):
|
||||
raise ValueError(
|
||||
f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501
|
||||
)
|
||||
|
||||
if isinstance(data_1, str):
|
||||
data_1 = [data_1]
|
||||
elif isinstance(data_1, dict):
|
||||
data_1 = data_1.get("content") # type: ignore[assignment]
|
||||
|
||||
if isinstance(data_2, str):
|
||||
data_2 = [data_2]
|
||||
elif isinstance(data_2, dict):
|
||||
data_2 = data_2.get("content") # type: ignore[assignment]
|
||||
|
||||
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
data_1=data_1, # type: ignore[arg-type]
|
||||
data_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=data_1, # type: ignore[arg-type]
|
||||
texts_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[ScoreResponse, ErrorResponse]:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.text_1,
|
||||
request.text_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self.models.model_name(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
top_n = request.top_n if request.top_n > 0 else (
|
||||
len(documents)
|
||||
if isinstance(documents, list) else len(documents["content"]))
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self.models.model_name(),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: list[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: Union[list[str], ScoreMultiModalParam],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]) if isinstance(
|
||||
documents, list) else RerankDocument(
|
||||
multi_modal=documents["content"][idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
196
vllm/entrypoints/openai/serving_tokenization.py
Normal file
196
vllm/entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TokenizerInfoResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
request: TokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[TokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
tool_dicts = (None if request.tools is None else
|
||||
[tool.model_dump() for tool in request.tools])
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.prompt,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
input_ids: list[int] = []
|
||||
for engine_prompt in engine_prompts:
|
||||
self._log_inputs(request_id,
|
||||
engine_prompt,
|
||||
params=None,
|
||||
lora_request=lora_request)
|
||||
|
||||
if isinstance(engine_prompt,
|
||||
dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
token_strs=token_strs,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len)
|
||||
|
||||
async def create_detokenize(
|
||||
self,
|
||||
request: DetokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[DetokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.tokens,
|
||||
params=None,
|
||||
lora_request=lora_request)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
|
||||
async def get_tokenizer_info(
|
||||
self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
|
||||
"""Get comprehensive tokenizer information."""
|
||||
try:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
|
||||
return TokenizerInfoResponse(**info)
|
||||
except Exception as e:
|
||||
return self.create_error_response(
|
||||
f"Failed to get tokenizer info: {str(e)}")
|
||||
|
||||
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
|
||||
return RenderConfig(add_special_tokens=request.add_special_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
tokenizer: AnyTokenizer
|
||||
chat_template: Optional[str]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return the tokenizer configuration."""
|
||||
return self._get_tokenizer_config()
|
||||
|
||||
def _get_tokenizer_config(self) -> dict[str, Any]:
|
||||
"""Get tokenizer configuration directly from the tokenizer object."""
|
||||
config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
|
||||
|
||||
# Remove file path fields
|
||||
config.pop("vocab_file", None)
|
||||
config.pop("merges_file", None)
|
||||
|
||||
config = self._make_json_serializable(config)
|
||||
config["tokenizer_class"] = type(self.tokenizer).__name__
|
||||
if self.chat_template:
|
||||
config["chat_template"] = self.chat_template
|
||||
return config
|
||||
|
||||
def _make_json_serializable(self, obj):
|
||||
"""Convert any non-JSON-serializable objects to serializable format."""
|
||||
if hasattr(obj, "content"):
|
||||
return obj.content
|
||||
elif isinstance(obj, dict):
|
||||
return {k: self._make_json_serializable(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self._make_json_serializable(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
136
vllm/entrypoints/openai/serving_transcription.py
Normal file
136
vllm/entrypoints/openai/serving_transcription.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe",
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
|
||||
ErrorResponse]:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self, request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
388
vllm/entrypoints/openai/speech_to_text.py
Normal file
388
vllm/entrypoints/openai/speech_to_text.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import cached_property
|
||||
from typing import Callable, Literal, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
SpeechToTextRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse]
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: Literal["transcribe", "translate"] = "transcribe",
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
self.task_type = task_type
|
||||
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
model_config, task_type)
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsTranscription]:
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsTranscription], model_cls)
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
# Skip to_language validation to avoid extra logging for Whisper.
|
||||
to_language = self.model_cls.validate_language(request.to_language) \
|
||||
if request.to_language else None
|
||||
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (self.asr_config.allow_audio_chunking
|
||||
and duration > self.asr_config.max_audio_clip_s)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
# returns a valid PromptType.
|
||||
prompt = self.model_cls.get_generation_prompt(
|
||||
audio=chunk,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language=language,
|
||||
task_type=self.task_type,
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, duration
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[T],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> Union[T, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ['text', 'json']:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for "
|
||||
f"{self.task_type.title()}.")
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
|
||||
None]]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
# It will not display special tokens like <|startoftranscript|>
|
||||
request.prompt,
|
||||
params=sampling_params,
|
||||
lora_request=None)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
) for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(request, list_result_generator,
|
||||
request_id, request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
|
||||
if self.task_type == "transcribe":
|
||||
# add usage in TranscriptionResponse.
|
||||
usage = {
|
||||
"type": "duration",
|
||||
# rounded up as per openAI specs
|
||||
"seconds": int(math.ceil(duration_s)),
|
||||
}
|
||||
final_response = cast(T, response_class(text=text,
|
||||
usage=usage))
|
||||
else:
|
||||
# no usage in response for translation task
|
||||
final_response = cast(
|
||||
T, response_class(text=text)) # type: ignore[call-arg]
|
||||
|
||||
return final_response
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
|
||||
response_stream_choice_class: Union[
|
||||
type[TranscriptionResponseStreamChoice],
|
||||
type[TranslationResponseStreamChoice]],
|
||||
stream_response_class: Union[type[TranscriptionStreamResponse],
|
||||
type[TranslationStreamResponse]],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens = len(res.prompt_token_ids)
|
||||
if audio_tokens := self.model_cls.get_num_audio_tokens(
|
||||
audio_duration_s, self.asr_config,
|
||||
self.model_config):
|
||||
num_prompt_tokens += audio_tokens
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = stream_response_class(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(self, audio_data: np.ndarray,
|
||||
sample_rate: int) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
|
||||
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start,
|
||||
search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int,
|
||||
end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
min_energy_window = self.asr_config.min_energy_split_window_size
|
||||
assert min_energy_window is not None
|
||||
for i in range(0, len(segment) - min_energy_window, min_energy_window):
|
||||
window = segment[i:i + min_energy_window]
|
||||
energy = (window**2).mean()**0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
|
||||
55
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
55
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .jamba_tool_parser import JambaToolParser
|
||||
from .kimi_k2_tool_parser import KimiK2ToolParser
|
||||
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .longcat_tool_parser import LongcatFlashToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .openai_tool_parser import OpenAIToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
from .qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
from .seed_oss_tool_parser import SeedOssToolParser
|
||||
from .step3_tool_parser import Step3ToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser",
|
||||
"ToolParserManager",
|
||||
"Granite20bFCToolParser",
|
||||
"GraniteToolParser",
|
||||
"Hermes2ProToolParser",
|
||||
"MistralToolParser",
|
||||
"Internlm2ToolParser",
|
||||
"Llama3JsonToolParser",
|
||||
"JambaToolParser",
|
||||
"Llama4PythonicToolParser",
|
||||
"LongcatFlashToolParser",
|
||||
"PythonicToolParser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"DeepSeekV31ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
"Qwen3CoderToolParser",
|
||||
"Qwen3XMLToolParser",
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
"OpenAIToolParser",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
164
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
164
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from functools import cached_property
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import import_from_path, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""
|
||||
Abstract ToolParser class that should not be used directly. Provided
|
||||
properties and methods should be used in
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
# whereas all tokenizers have .get_vocab()
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
Used for non-streaming responses where we have the entire model response
|
||||
available before sending to the client.
|
||||
Static because it's stateless.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls has not been implemented!")
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
from an incomplete response; for use when handling tool calls and
|
||||
streaming. Has to be an instance method because it requires state -
|
||||
the current tokens/diffs, but also the information about what has
|
||||
previously been parsed and extracted (see constructor)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
|
||||
tool_parsers: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name) -> type:
|
||||
"""
|
||||
Get tool parser by name which is registered by `register_module`.
|
||||
|
||||
Raise a KeyError exception if the name is not registered.
|
||||
"""
|
||||
if name in cls.tool_parsers:
|
||||
return cls.tool_parsers[name]
|
||||
|
||||
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||
|
||||
@classmethod
|
||||
def _register_module(cls,
|
||||
module: type,
|
||||
module_name: Optional[Union[str, list[str]]] = None,
|
||||
force: bool = True) -> None:
|
||||
if not issubclass(module, ToolParser):
|
||||
raise TypeError(
|
||||
f'module must be subclass of ToolParser, but got {type(module)}'
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in cls.tool_parsers:
|
||||
existed_module = cls.tool_parsers[name]
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'at {existed_module.__module__}')
|
||||
cls.tool_parsers[name] = module
|
||||
|
||||
@classmethod
|
||||
def register_module(
|
||||
cls,
|
||||
name: Optional[Union[str, list[str]]] = None,
|
||||
force: bool = True,
|
||||
module: Union[type, None] = None) -> Union[type, Callable]:
|
||||
"""
|
||||
Register module with the given name or name list. it can be used as a
|
||||
decoder(with module as None) or normal function(with module as not
|
||||
None).
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str)
|
||||
or is_list_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be None, an instance of str, or a sequence of str, '
|
||||
f'but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||
"""
|
||||
Import a user-defined tool parser by the path of the tool parser define
|
||||
file.
|
||||
"""
|
||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||
|
||||
try:
|
||||
import_from_path(module_name, plugin_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to load module '%s' from %s.",
|
||||
module_name, plugin_path)
|
||||
return
|
||||
367
vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
Normal file
367
vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("deepseek_v31")
|
||||
class DeepSeekV31ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool▁call▁begin|>"
|
||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)")
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<function_name>.*)<|tool▁sep|>")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"DeepSeek-V3.1 Tool parser could not locate tool call "
|
||||
"start/end tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
function_name, function_args = match
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_name, tool_args = current_tool_call_matches.groups()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_name = current_tool_call_name_matches.groups()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
370
vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
Normal file
370
vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("deepseek_v3")
|
||||
class DeepSeekV3ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool▁call▁begin|>"
|
||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
|
||||
)
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"DeepSeek-V3 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
tool_type, function_name, function_args = match
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type=tool_type,
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_type, tool_name, tool_args = (
|
||||
current_tool_call_matches.groups())
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_type, tool_name = (
|
||||
current_tool_call_name_matches.groups())
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
185
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
Normal file
185
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("glm45")
|
||||
class Glm4MoeModelToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.tool_call_start_token = "<tool_call>"
|
||||
self.tool_call_end_token = "</tool_call>"
|
||||
|
||||
self.tool_calls_start_token = self.tool_call_start_token
|
||||
|
||||
self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>",
|
||||
re.DOTALL)
|
||||
self.func_detail_regex = re.compile(
|
||||
r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
||||
self.func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL)
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
self._buffer = ""
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str, arg_name: str,
|
||||
tools: Optional[list[ChatCompletionToolsParam]]) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools:
|
||||
if tool.function.name == tool_name:
|
||||
if tool.function.parameters is None:
|
||||
return False
|
||||
arg_type = tool.function.parameters.get(
|
||||
"properties", {}).get(arg_name, {}).get("type", None)
|
||||
return arg_type == "string"
|
||||
logger.warning("No tool named '%s'.", tool_name)
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
matched_tool_calls = self.func_call_regex.findall(model_output)
|
||||
logger.debug("model_output: %s", model_output)
|
||||
try:
|
||||
tool_calls = []
|
||||
for match in matched_tool_calls:
|
||||
tc_detail = self.func_detail_regex.search(match)
|
||||
tc_name = tc_detail.group(1)
|
||||
tc_args = tc_detail.group(2)
|
||||
pairs = self.func_arg_regex.findall(tc_args)
|
||||
arg_dct = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = value.strip()
|
||||
if not _is_string_type(tc_name, arg_key, request.tools):
|
||||
arg_val = _deserialize(arg_val)
|
||||
logger.debug("arg_key = %s, arg_val = %s", arg_key,
|
||||
arg_val)
|
||||
arg_dct[arg_key] = arg_val
|
||||
tool_calls.append(
|
||||
ToolCall(type="function",
|
||||
function=FunctionCall(
|
||||
name=tc_name, arguments=json.dumps(arg_dct))))
|
||||
except Exception:
|
||||
logger.exception("Failed to extract tool call spec")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
else:
|
||||
if len(tool_calls) > 0:
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
self._buffer += delta_text
|
||||
cur_text = self._buffer
|
||||
start_idx = cur_text.find(self.tool_call_start_token)
|
||||
if start_idx == -1:
|
||||
self._buffer = ""
|
||||
if self.current_tool_id > 0:
|
||||
cur_text = ""
|
||||
return DeltaMessage(content=cur_text)
|
||||
logger.debug("cur_text = %s", cur_text)
|
||||
end_idx = cur_text.find(self.tool_call_end_token)
|
||||
if end_idx != -1:
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
self.prev_tool_call_arr = []
|
||||
self.streamed_args_for_tool = []
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
extracted_tool_calls = self.extract_tool_calls(
|
||||
cur_text[:end_idx + len(self.tool_call_end_token)], request)
|
||||
|
||||
if len(extracted_tool_calls.tool_calls) == 0:
|
||||
logger.warning("Failed to extract any tool calls.")
|
||||
return None
|
||||
tool_call = extracted_tool_calls.tool_calls[0]
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments)
|
||||
}
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = tool_call.function.arguments
|
||||
delta = DeltaMessage(
|
||||
content=extracted_tool_calls.content,
|
||||
tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments))
|
||||
])
|
||||
self.current_tool_id += 1
|
||||
self._buffer = cur_text[end_idx + len(self.tool_call_end_token):]
|
||||
return delta
|
||||
|
||||
self._buffer = cur_text[start_idx:]
|
||||
return DeltaMessage(content=cur_text[:start_idx])
|
||||
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecoder
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("granite-20b-fc")
|
||||
class Granite20bFCToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for the granite-20b-functioncalling model intended
|
||||
for use with the examples/tool_chat_template_granite20b_fc.jinja
|
||||
template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.bot_token = "<function_call>"
|
||||
self.tool_start_token = self.bot_token
|
||||
self.tool_call_regex = re.compile(r"<function_call>\s*")
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
if self.tool_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
dec = JSONDecoder()
|
||||
try:
|
||||
matches = list(self.tool_call_regex.finditer(model_output))
|
||||
logger.debug("Found %d tool call matches", len(matches))
|
||||
|
||||
raw_function_calls = []
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
# position after the <function_call> tag
|
||||
start_of_json = match.end()
|
||||
# end_index == the start of the next function call
|
||||
# (if exists)
|
||||
next_function_call_start = (matches[i + 1].start() if i +
|
||||
1 < len(matches) else None)
|
||||
|
||||
raw_function_calls.append(
|
||||
dec.raw_decode(
|
||||
model_output[start_of_json:next_function_call_start])
|
||||
[0])
|
||||
|
||||
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
),
|
||||
) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.find(self.bot_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s", e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if len(current_text) < len(
|
||||
self.bot_token) and self.bot_token.startswith(current_text):
|
||||
return None
|
||||
|
||||
if not current_text.startswith(self.bot_token):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
start_idx = len(self.bot_token)
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
|
||||
while start_idx < len(current_text):
|
||||
(obj,
|
||||
end_idx) = partial_json_loads(current_text[start_idx:],
|
||||
flags)
|
||||
is_complete.append(
|
||||
is_complete_json(current_text[start_idx:start_idx +
|
||||
end_idx]))
|
||||
start_idx += end_idx
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
start_idx += len(self.bot_token)
|
||||
start_idx = consume_space(start_idx, current_text)
|
||||
tool_call_arr.append(obj)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
delta = None
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
237
vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
Normal file
237
vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("granite")
|
||||
class GraniteToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for the granite 3.0 models. Intended
|
||||
for use with the examples/tool_chat_template_granite.jinja
|
||||
template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser granite
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
# for granite 3.0, the token `<|tool_call|>`
|
||||
self.bot_token = "<|tool_call|>"
|
||||
# for granite 3.1, the string `<tool_call>`
|
||||
self.bot_string = "<tool_call>"
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
stripped = model_output.strip()\
|
||||
.removeprefix(self.bot_token)\
|
||||
.removeprefix(self.bot_string)\
|
||||
.lstrip()
|
||||
if not stripped or stripped[0] != '[':
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
try:
|
||||
raw_function_calls = json.loads(stripped)
|
||||
if not isinstance(raw_function_calls, list):
|
||||
raise Exception(
|
||||
f"Expected dict or list, got {type(raw_function_calls)}")
|
||||
|
||||
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
),
|
||||
) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s", e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
start_idx = consume_space(0, current_text)
|
||||
if current_text[start_idx:].startswith(self.bot_token):
|
||||
start_idx = consume_space(start_idx + len(self.bot_token),
|
||||
current_text)
|
||||
if current_text[start_idx:].startswith(self.bot_string):
|
||||
start_idx = consume_space(start_idx + len(self.bot_string),
|
||||
current_text)
|
||||
if not current_text or start_idx >= len(current_text)\
|
||||
or current_text[start_idx] != '[':
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = None
|
||||
is_complete = None
|
||||
try:
|
||||
tool_calls, end_idx = partial_json_loads(
|
||||
current_text[start_idx:], flags)
|
||||
if type(tool_calls) is list:
|
||||
tool_call_arr = tool_calls
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
is_complete = [True] * len(tool_calls)
|
||||
if not is_complete_json(
|
||||
current_text[start_idx:start_idx + end_idx]):
|
||||
is_complete[-1] = False
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if not tool_call_arr:
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id]
|
||||
|
||||
delta = None
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
if len(tool_call_arr) > self.current_tool_id + 1:
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
455
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
455
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
@@ -0,0 +1,455 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.error(
|
||||
"Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
|
||||
self.scratch_pad_regex = re.compile(
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_start_token, add_special_tokens=False)
|
||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_end_token, add_special_tokens=False)
|
||||
|
||||
self.tool_call_start_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_start_token_ids
|
||||
]
|
||||
|
||||
self.tool_call_end_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_end_token_ids
|
||||
]
|
||||
|
||||
self.buffered_delta_text = ""
|
||||
|
||||
# Very simple idea: when encountering tokens like <, tool, _call, >,
|
||||
# <, /, tool, _call, >, store them in a buffer.
|
||||
# When the last token is encountered, empty the buffer and return it.
|
||||
# If a token appears in an incorrect sequence while storing in the buffer,
|
||||
# return the preceding buffer along with the token.
|
||||
def tool_call_delta_buffer(self, delta_text: str):
|
||||
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
||||
# complete, fill the buffer with the token and return "".
|
||||
if (delta_text in self.tool_call_start_token_array
|
||||
or delta_text in self.tool_call_end_token_array):
|
||||
# If delta_text is the last token of tool_call_start_token or
|
||||
# tool_call_end_token, empty the buffer and return
|
||||
# the buffered text + delta_text.
|
||||
if (delta_text == self.tool_call_start_token_array[-1]
|
||||
or delta_text == self.tool_call_end_token_array[-1]):
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
self.buffered_delta_text = self.buffered_delta_text + delta_text
|
||||
return ""
|
||||
else:
|
||||
if self.buffered_delta_text:
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
return delta_text
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because the tool_call tokens are
|
||||
# marked "special" in some models. Since they are skipped
|
||||
# prior to the call to the tool parser, it breaks tool calling.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = (
|
||||
self.tool_call_regex.findall(model_output))
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = [
|
||||
json.loads(match[0] if match[0] else match[1])
|
||||
for match in function_call_tuples
|
||||
]
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False)))
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_call_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# 1. All tokens are parsed based on _text, not token_ids.
|
||||
# 2. All incoming text data is processed by the tool_call_delta_buffer
|
||||
# function for buffering before being used for parsing.
|
||||
|
||||
delta_text = self.tool_call_delta_buffer(delta_text)
|
||||
# If the last characters of previous_text
|
||||
# match self.buffered_delta_text, remove only the matching part.
|
||||
if (len(previous_text) >= len(self.buffered_delta_text)
|
||||
and previous_text[-len(self.buffered_delta_text):]
|
||||
== self.buffered_delta_text):
|
||||
previous_text = previous_text[:-len(self.buffered_delta_text)]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_call_start_token not in current_text:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_text.count(
|
||||
self.tool_call_start_token)
|
||||
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
|
||||
cur_tool_start_count = current_text.count(
|
||||
self.tool_call_start_token)
|
||||
cur_tool_end_count = current_text.count(self.tool_call_end_token)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case: if tool open & close tag counts don't match, we're doing
|
||||
# imaginary "else" block here
|
||||
# something with tools with this diff.
|
||||
# flags for partial JSON parting. exported constants from
|
||||
# "Allow" are handled via BIT MASK
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if (self.prev_tool_call_arr is None
|
||||
or len(self.prev_tool_call_arr) == 0):
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = diff.encode('utf-8').decode(
|
||||
'unicode_escape') if diff is str else diff
|
||||
if ('"}' not in delta_text):
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
try:
|
||||
|
||||
current_tool_call = partial_json_parser.loads(
|
||||
tool_call_portion or "{}",
|
||||
flags) if tool_call_portion else None
|
||||
logger.debug("Parsed tool call %s", current_tool_call)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.debug("unable to parse JSON")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if (current_tool_call is None):
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
else:
|
||||
return None
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = DeltaMessage(content=delta_text) \
|
||||
if text_portion is not None else None
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = (
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
# extract the content after {"name": ..., "arguments":
|
||||
# directly from tool_call_portion as cur_arguments_json,
|
||||
# since cur_arguments may differ from the original text
|
||||
# due to partial JSON parsing
|
||||
# for example, tool_call_portion =
|
||||
# {"name": "search", "arguments": {"search_request": {"
|
||||
# but cur_arguments =
|
||||
# {"search_request": {}}
|
||||
function_name = current_tool_call.get("name")
|
||||
match = re.search(
|
||||
r'\{"name":\s*"' +
|
||||
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||
tool_call_portion.strip(), re.DOTALL)
|
||||
if match:
|
||||
cur_arguments_json = match.group(1)
|
||||
else:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current.
|
||||
if (delta_text not in cur_arguments_json):
|
||||
return None
|
||||
args_delta_start_loc = cur_arguments_json. \
|
||||
rindex(delta_text) + \
|
||||
len(delta_text)
|
||||
|
||||
# use that to find the actual delta
|
||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= arguments_delta
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
# judge whether the tool_call_portion is a complete JSON
|
||||
try:
|
||||
json.loads(tool_call_portion)
|
||||
is_complete_json = True
|
||||
except Exception:
|
||||
is_complete_json = False
|
||||
|
||||
# if the delta_text ends with a '}' and tool_call_portion is a
|
||||
# complete JSON, then the last '}' does not belong to the
|
||||
# arguments, so we should trim it off
|
||||
if isinstance(delta_text, str) \
|
||||
and len(delta_text.rstrip()) >= 1 \
|
||||
and delta_text.rstrip()[-1] == '}' \
|
||||
and is_complete_json:
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= delta_text
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = \
|
||||
current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
372
vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
372
vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501, SIM102
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hunyuan_a13b")
|
||||
class HunyuanA13BToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize state for streaming mode
|
||||
self.prev_tool_calls: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args: list[str] = [
|
||||
] # Track arguments sent for each tool
|
||||
|
||||
# For backward compatibility with tests
|
||||
self.current_tools_sent: list[bool] = []
|
||||
|
||||
# For backward compatibility with serving code
|
||||
self.prev_tool_call_arr = []
|
||||
|
||||
# Regex patterns for preprocessing
|
||||
self.answer_tool_calls_pattern = re.compile(
|
||||
r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
|
||||
|
||||
self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
|
||||
|
||||
self.tool_empty_arg_reg = re.compile(
|
||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
||||
|
||||
# TODO: not support nested json object in fc arguments.
|
||||
self.tool_non_empty_arg_reg = re.compile(
|
||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
||||
)
|
||||
|
||||
self.bot_string = "<tool_calls>"
|
||||
|
||||
# Define streaming state type to be initialized later
|
||||
self.streaming_state: dict[str, Any] = {
|
||||
"current_tool_index": -1,
|
||||
"tool_ids": [],
|
||||
"sent_tools": [],
|
||||
}
|
||||
|
||||
def preprocess_model_output(
|
||||
self, model_output: str) -> tuple[Optional[str], Optional[str]]:
|
||||
# find the location tool call
|
||||
for match in self.answer_tool_calls_pattern.finditer(model_output):
|
||||
start, end = match.span()
|
||||
# check tool_calls whether in side of <think>
|
||||
think_regions = [(m.start(), m.end()) for m in re.finditer(
|
||||
r"<think>(.*?)</think>", model_output, flags=re.DOTALL)]
|
||||
in_think = any(start > t_start and end < t_end
|
||||
for t_start, t_end in think_regions)
|
||||
if not in_think:
|
||||
content = model_output[:start]
|
||||
tool_calls_content = match.group(1).strip()
|
||||
try:
|
||||
json.loads(tool_calls_content)
|
||||
return content, tool_calls_content
|
||||
except Exception:
|
||||
continue
|
||||
return model_output, None
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract tool calls from a complete model output.
|
||||
"""
|
||||
try:
|
||||
# Preprocess the model output
|
||||
content, potential_tool_calls = self.preprocess_model_output(
|
||||
model_output)
|
||||
|
||||
if not potential_tool_calls:
|
||||
# some text should be filtered out for no function call
|
||||
# this text is in a13b's chat template.
|
||||
if content:
|
||||
content = content.replace("助手:", "", 1)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=content)
|
||||
|
||||
# Parse the potential tool calls as JSON
|
||||
tool_calls_data = json.loads(potential_tool_calls)
|
||||
|
||||
# Ensure it's an array
|
||||
if not isinstance(tool_calls_data, list):
|
||||
logger.debug("Tool calls data is not an array")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=content or model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
for idx, call in enumerate(tool_calls_data):
|
||||
if (not isinstance(call, dict) or "name" not in call
|
||||
or "arguments" not in call):
|
||||
continue
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=f"call_{random_uuid()}",
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=call["name"],
|
||||
arguments=(json.dumps(call["arguments"]) if isinstance(
|
||||
call["arguments"], dict) else call["arguments"]),
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if not content or len(content.strip()) == 0:
|
||||
# clear the whitespace content.
|
||||
content = None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract tool calls for streaming mode.
|
||||
"""
|
||||
|
||||
start_idx = consume_space(0, current_text)
|
||||
if current_text[start_idx:].startswith(self.bot_string):
|
||||
start_idx = consume_space(start_idx + len(self.bot_string),
|
||||
current_text)
|
||||
if not current_text or start_idx >= len(
|
||||
current_text) or current_text[start_idx] != '[':
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
self._try_parse_json_tools(current_text[start_idx:])
|
||||
|
||||
test_delta = self._handle_test_compatibility(current_text)
|
||||
if test_delta:
|
||||
return test_delta
|
||||
|
||||
name_matches = list(self.tool_name_reg.finditer(current_text))
|
||||
tool_count = len(name_matches)
|
||||
if tool_count == 0:
|
||||
return None
|
||||
self._ensure_state_arrays(tool_count)
|
||||
current_idx = self.streaming_state["current_tool_index"]
|
||||
|
||||
name_delta = self._handle_tool_name_streaming(current_idx, tool_count,
|
||||
name_matches)
|
||||
if name_delta:
|
||||
return name_delta
|
||||
|
||||
args_delta = self._handle_tool_args_streaming(current_text,
|
||||
current_idx, tool_count)
|
||||
if args_delta:
|
||||
return args_delta
|
||||
|
||||
return None
|
||||
|
||||
def _try_parse_json_tools(self, current_text: str):
|
||||
try:
|
||||
parsed_tools = json.loads(current_text)
|
||||
if isinstance(parsed_tools, list):
|
||||
self.prev_tool_call_arr = parsed_tools
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def _handle_test_compatibility(self, current_text: str):
|
||||
if len(self.current_tools_sent) > 0:
|
||||
if (len(self.current_tools_sent) == 1
|
||||
and self.current_tools_sent[0] is False):
|
||||
name_match = self.tool_name_reg.search(current_text)
|
||||
if name_match:
|
||||
function_name = name_match.group(1)
|
||||
tool_id = f"chatcmpl-tool-{random_uuid()}"
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.current_tools_sent = [True]
|
||||
self.current_tool_id = 0
|
||||
self.streaming_state["current_tool_index"] = 0
|
||||
if len(self.streaming_state["sent_tools"]) == 0:
|
||||
self.streaming_state["sent_tools"].append({
|
||||
"sent_name":
|
||||
True,
|
||||
"sent_arguments_prefix":
|
||||
False,
|
||||
"sent_arguments":
|
||||
"",
|
||||
})
|
||||
else:
|
||||
self.streaming_state["sent_tools"][0][
|
||||
"sent_name"] = True
|
||||
self.current_tool_name_sent = True
|
||||
return delta
|
||||
return None
|
||||
|
||||
def _ensure_state_arrays(self, tool_count: int):
|
||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||
self.streaming_state["sent_tools"].append({
|
||||
"sent_name": False,
|
||||
"sent_arguments_prefix": False,
|
||||
"sent_arguments": "",
|
||||
})
|
||||
while len(self.streaming_state["tool_ids"]) < tool_count:
|
||||
self.streaming_state["tool_ids"].append(None)
|
||||
|
||||
def _handle_tool_name_streaming(self, current_idx: int, tool_count: int,
|
||||
name_matches):
|
||||
if current_idx == -1 or current_idx < tool_count - 1:
|
||||
next_idx = current_idx + 1
|
||||
if (next_idx < tool_count
|
||||
and not self.streaming_state["sent_tools"][next_idx]
|
||||
["sent_name"]):
|
||||
self.streaming_state["current_tool_index"] = next_idx
|
||||
self.current_tool_id = next_idx
|
||||
current_idx = next_idx
|
||||
tool_name = name_matches[current_idx].group(1)
|
||||
tool_id = f"call_{current_idx}_{random_uuid()}"
|
||||
self.streaming_state["tool_ids"][current_idx] = tool_id
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(name=tool_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_name"] = True
|
||||
self.current_tool_name_sent = True
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
return delta
|
||||
return None
|
||||
|
||||
def _handle_tool_args_streaming(self, current_text: str, current_idx: int,
|
||||
tool_count: int):
|
||||
|
||||
if current_idx >= 0 and current_idx < tool_count:
|
||||
empty_args_match = self.tool_empty_arg_reg.search(current_text)
|
||||
if empty_args_match and empty_args_match.start() > 0:
|
||||
for i in range(tool_count):
|
||||
if i == current_idx:
|
||||
if not self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"]:
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"] = True
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"] = "{}"
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += "{}"
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{}").model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
if current_idx < tool_count - 1:
|
||||
self.streaming_state["current_tool_index"] += 1
|
||||
self.current_tool_id = self.streaming_state[
|
||||
"current_tool_index"]
|
||||
return delta
|
||||
|
||||
args_matches = list(
|
||||
self.tool_non_empty_arg_reg.finditer(current_text))
|
||||
if current_idx < len(args_matches):
|
||||
args_text = args_matches[current_idx].group(1)
|
||||
is_last_tool = current_idx == tool_count - 1
|
||||
if not is_last_tool:
|
||||
next_tool_pos = current_text.find(
|
||||
"},{", args_matches[current_idx].start())
|
||||
if next_tool_pos != -1:
|
||||
args_end_pos = (next_tool_pos + 1)
|
||||
args_text = (
|
||||
current_text[args_matches[current_idx].start(
|
||||
):args_end_pos].split('"arguments":')[1].strip())
|
||||
sent_args = self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"]
|
||||
if not self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"] and args_text.startswith("{"):
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"] = True
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"] = "{"
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += "{"
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{").model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
if args_text.startswith(sent_args):
|
||||
args_diff = args_text[len(sent_args):]
|
||||
if args_diff:
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"] = args_text
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += args_diff
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=args_diff).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
if args_text.endswith("}") and args_text == sent_args:
|
||||
if current_idx < tool_count - 1:
|
||||
self.streaming_state["current_tool_index"] += 1
|
||||
self.current_tool_id = self.streaming_state[
|
||||
"current_tool_index"]
|
||||
return None
|
||||
216
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
216
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
|
||||
class Internlm2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because internlm use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def get_arguments(self, obj):
|
||||
if "parameters" in obj:
|
||||
return obj.get("parameters")
|
||||
elif "arguments" in obj:
|
||||
return obj.get("arguments")
|
||||
return None
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
if '<|action_start|>' not in current_text:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=delta_text)
|
||||
# if the tool call is sent, return an empty delta message
|
||||
# to make sure the finish_reason will be sent correctly.
|
||||
if self.current_tool_id > 0:
|
||||
return DeltaMessage(content='')
|
||||
|
||||
last_pos = self.position
|
||||
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
|
||||
return None
|
||||
|
||||
new_delta = current_text[last_pos:]
|
||||
text, action = new_delta.split('<|action_start|><|plugin|>')
|
||||
|
||||
if len(text) > 0:
|
||||
self.position = self.position + len(text)
|
||||
return DeltaMessage(content=text)
|
||||
|
||||
action = action.strip()
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
try:
|
||||
parsable_arr = action
|
||||
|
||||
# tool calls are generated in an object in internlm2
|
||||
# it's not support parallel tool calls
|
||||
try:
|
||||
tool_call_arr: dict = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = tool_call_arr.get("name")
|
||||
if function_name:
|
||||
self.current_tool_id = self.current_tool_id + 1
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
self.streamed_args_for_tool.append("")
|
||||
else:
|
||||
delta = None
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
prev_arguments = self.get_arguments(
|
||||
self.prev_tool_call_arr[self.current_tool_id])
|
||||
cur_arguments = self.get_arguments(tool_call_arr)
|
||||
|
||||
# not arguments generated
|
||||
if not cur_arguments and not prev_arguments:
|
||||
delta = None
|
||||
# will never happen
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
# first time to get parameters
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(delta_text) +
|
||||
len(delta_text)]
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
# both prev and cur parameters, send the increase parameters
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
tool_call_arr["arguments"] = self.get_arguments(tool_call_arr)
|
||||
self.prev_tool_call_arr = [tool_call_arr]
|
||||
return delta
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
text = model_output
|
||||
tools = request.tools
|
||||
if '<|action_start|><|plugin|>' in text:
|
||||
text, action = text.split('<|action_start|><|plugin|>')
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
action = action[action.find('{'):]
|
||||
action_dict = json.loads(action)
|
||||
name, parameters = action_dict['name'], json.dumps(
|
||||
action_dict.get('parameters', action_dict.get('arguments',
|
||||
{})),
|
||||
ensure_ascii=False)
|
||||
|
||||
if not tools or name not in [t.function.name for t in tools]:
|
||||
ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(name=name, arguments=parameters))
|
||||
]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=text if len(text) > 0 else None)
|
||||
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
308
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
308
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("jamba")
|
||||
class JambaToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Detected a MistralTokenizer tokenizer when using a Jamba model"
|
||||
)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<tool_calls>"
|
||||
self.tool_calls_end_token: str = "</tool_calls>"
|
||||
|
||||
self.tool_calls_regex = re.compile(
|
||||
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
|
||||
re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Jamba Tool parser could not locate tool calls start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because jamba use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# use a regex to find the tool call between the tags
|
||||
function_calls = self.tool_calls_regex.findall(model_output)[0]
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = json.loads(function_calls)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False),
|
||||
)) for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if
|
||||
(len(content) > 0 and content != " ") else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.tool_calls_start_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the start of tool calls token which means
|
||||
# the start of tool calling
|
||||
if (self.tool_calls_start_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion and don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# Extract the tool calls between the special tool call tokens
|
||||
parsable_arr = current_text.split(
|
||||
self.tool_calls_start_token)[-1].split(
|
||||
self.tool_calls_end_token)[0]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: list[dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff, ensure_ascii=False).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
377
vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
Normal file
377
vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# code modified from deepseekv3_tool_parser.py
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["kimi_k2"])
|
||||
class KimiK2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool_call_begin|>"
|
||||
self.tool_call_end_token: str = "<|tool_call_end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
|
||||
)
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<tool_call_id>.+:\d+)\s*")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Kimi-K2 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
logger.debug("function_call_tuples: %s", function_call_tuples)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
function_id, function_args = match
|
||||
# function_id: functions.get_weather:0
|
||||
function_name = function_id.split('.')[1].split(':')[0]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=function_id,
|
||||
type='function',
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_id, tool_args = (current_tool_call_matches.groups())
|
||||
tool_name = tool_id.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_id_str, = current_tool_call_name_matches.groups()
|
||||
tool_name = tool_id_str.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id_str
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
tool_id = current_tool_call.get("id")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
@@ -0,0 +1,316 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama4_pythonic")
|
||||
class Llama4PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Toolcall parser for Llama4 that produce tool calls in a pythonic style
|
||||
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
|
||||
"""
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
|
||||
# remove <|python_start|> and <|python_end|>
|
||||
# as Llama 4 model sometime will output those tokens
|
||||
if model_output.startswith("<|python_start|>"):
|
||||
model_output = model_output[len("<|python_start|>"):]
|
||||
model_output = model_output.replace("<|python_end|>", "")
|
||||
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not current_text.startswith("[") and not current_text.startswith(
|
||||
"<|python_start|>"):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
# remove <|python_start|> and <|python_end|>
|
||||
if current_text.startswith("<|python_start|>"):
|
||||
current_text = current_text[len("<|python_start|>"):]
|
||||
if current_text.endswith("<|python_end|>"):
|
||||
current_text = current_text[:current_text.
|
||||
rfind("<|python_end|>")]
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(
|
||||
tool_calls) - 1 or ")]" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = (added_text[:-2]
|
||||
if not new_call_complete else "")
|
||||
if not new_call_complete and added_text[-2] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
|
||||
new_call, index, withheld_suffix)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (delta.function is not None
|
||||
and delta.function.arguments is not None):
|
||||
self.streamed_args_for_tool[
|
||||
index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining its final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content='')
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError(
|
||||
"Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(arguments)))
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[:text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[:text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
|
||||
"[") and not text.endswith(")"):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
index: int,
|
||||
withheld_suffix: str) -> Union[DeltaToolCall, None]:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[:-len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
))
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args):]
|
||||
return DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(
|
||||
arguments=arg_diff)) if arg_diff else None
|
||||
269
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
269
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
|
||||
@ToolParserManager.register_module("llama4_json")
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.x and 4 models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
|
||||
llama4_json are set.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "<|python_tag|>"
|
||||
self.bot_token_id = tokenizer.encode(self.bot_token,
|
||||
add_special_tokens=False)[0]
|
||||
# Updated regex to match multiple JSONs separated by semicolons
|
||||
# This pattern is more robust and can handle nested JSON objects
|
||||
self.tool_call_regex = re.compile(
|
||||
r'{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*',
|
||||
re.DOTALL)
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
Only extracts JSON content and ignores any surrounding plain text.
|
||||
Supports both single JSON and multiple JSONs separated by semicolons.
|
||||
"""
|
||||
# Quick check before running regex
|
||||
if not (self.bot_token in model_output or '{' in model_output):
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# Find JSON object(s) in the text using regex
|
||||
match = self.tool_call_regex.search(model_output)
|
||||
if not match:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
json_str = match.group(0)
|
||||
# Split by semicolon and strip whitespace
|
||||
json_objects = [obj.strip() for obj in json_str.split(';')]
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
for json_obj in json_objects:
|
||||
if not json_obj: # Skip empty strings
|
||||
continue
|
||||
obj = json.loads(json_obj)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=obj["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
obj["arguments"]
|
||||
if "arguments" in obj else obj["parameters"],
|
||||
ensure_ascii=False))))
|
||||
|
||||
return ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not (current_text.startswith(self.bot_token)
|
||||
or current_text.startswith('{')):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = len(self.bot_token) if current_text.startswith(
|
||||
self.bot_token) else 0
|
||||
while start_idx < len(current_text):
|
||||
(obj,
|
||||
end_idx) = partial_json_loads(current_text[start_idx:],
|
||||
flags)
|
||||
is_complete.append(
|
||||
is_complete_json(current_text[start_idx:start_idx +
|
||||
end_idx]))
|
||||
start_idx += end_idx + len('; ')
|
||||
# depending on the prompt Llama can use
|
||||
# either arguments or parameters
|
||||
if "parameters" in obj:
|
||||
assert "arguments" not in obj, \
|
||||
"model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
delta = None
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
39
vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py
Normal file
39
vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||
Hermes2ProToolParser)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
@ToolParserManager.register_module("longcat")
|
||||
class LongcatFlashToolParser(Hermes2ProToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.tool_call_start_token: str = "<longcat_tool_call>"
|
||||
self.tool_call_end_token: str = "</longcat_tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)",
|
||||
re.DOTALL)
|
||||
|
||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_start_token, add_special_tokens=False)
|
||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_end_token, add_special_tokens=False)
|
||||
|
||||
self.tool_call_start_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_start_token_ids
|
||||
]
|
||||
|
||||
self.tool_call_end_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_end_token_ids
|
||||
]
|
||||
816
vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
Normal file
816
vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
Normal file
@@ -0,0 +1,816 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("minimax")
|
||||
class MinimaxToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize streaming state for tracking tool call progress
|
||||
self.streaming_state: dict[str, Any] = {
|
||||
"current_tool_index": -1, # Index of current tool being processed
|
||||
"tool_ids": [], # List of tool call IDs
|
||||
"sent_tools": [], # List of tools that have been sent
|
||||
}
|
||||
|
||||
# Define tool call tokens and patterns
|
||||
self.tool_call_start_token = "<tool_calls>"
|
||||
self.tool_call_end_token = "</tool_calls>"
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
|
||||
self.thinking_tag_pattern = r"<think>(.*?)</think>"
|
||||
self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
|
||||
self.tool_args_pattern = re.compile(r'"arguments":\s*')
|
||||
|
||||
# Buffer for handling partial tool calls during streaming
|
||||
self.pending_buffer = ""
|
||||
self.in_thinking_tag = False
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
|
||||
# Get token IDs for tool call start/end tokens
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
logger.warning(
|
||||
"Minimax Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer. Falling back to string matching.")
|
||||
|
||||
def preprocess_model_output(self, model_output: str) -> str:
|
||||
"""
|
||||
Preprocess model output by removing tool calls from thinking tags.
|
||||
|
||||
Args:
|
||||
model_output: Raw model output string
|
||||
|
||||
Returns:
|
||||
Preprocessed model output with tool calls removed from thinking tags
|
||||
"""
|
||||
|
||||
def remove_tool_calls_from_think(match):
|
||||
think_content = match.group(1)
|
||||
cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
|
||||
"",
|
||||
think_content,
|
||||
flags=re.DOTALL)
|
||||
return f"<think>{cleaned_content}</think>"
|
||||
|
||||
return re.sub(self.thinking_tag_pattern,
|
||||
remove_tool_calls_from_think,
|
||||
model_output,
|
||||
flags=re.DOTALL)
|
||||
|
||||
def _clean_duplicate_braces(self, args_text: str) -> str:
|
||||
"""
|
||||
Clean duplicate closing braces from arguments text.
|
||||
|
||||
Args:
|
||||
args_text: Raw arguments text
|
||||
|
||||
Returns:
|
||||
Cleaned arguments text with proper JSON formatting
|
||||
"""
|
||||
args_text = args_text.strip()
|
||||
if not args_text:
|
||||
return args_text
|
||||
|
||||
try:
|
||||
json.loads(args_text)
|
||||
return args_text
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
while args_text.endswith('}}'):
|
||||
candidate = args_text[:-1]
|
||||
try:
|
||||
json.loads(candidate)
|
||||
return candidate
|
||||
except json.JSONDecodeError:
|
||||
args_text = candidate
|
||||
|
||||
return args_text
|
||||
|
||||
def _clean_delta_braces(self, delta_text: str) -> str:
|
||||
"""
|
||||
Clean delta text by removing excessive closing braces.
|
||||
|
||||
Args:
|
||||
delta_text: Delta text to clean
|
||||
|
||||
Returns:
|
||||
Cleaned delta text
|
||||
"""
|
||||
if not delta_text:
|
||||
return delta_text
|
||||
|
||||
delta_stripped = delta_text.strip()
|
||||
|
||||
if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped):
|
||||
brace_count = delta_stripped.count('}')
|
||||
if brace_count > 1:
|
||||
return '}\n' if delta_text.endswith('\n') else '}'
|
||||
|
||||
return delta_text
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract tool calls from model output for non-streaming mode.
|
||||
|
||||
Args:
|
||||
model_output: Complete model output
|
||||
request: Chat completion request
|
||||
|
||||
Returns:
|
||||
ExtractedToolCallInformation containing tool calls and content
|
||||
"""
|
||||
processed_output = self.preprocess_model_output(model_output)
|
||||
|
||||
if self.tool_call_start_token not in processed_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
processed_output)
|
||||
|
||||
raw_function_calls = []
|
||||
for match in function_call_tuples:
|
||||
tool_call_content = match[0] if match[0] else match[1]
|
||||
if tool_call_content.strip():
|
||||
lines = tool_call_content.strip().split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and line.startswith('{') and line.endswith(
|
||||
'}'):
|
||||
try:
|
||||
parsed_call = json.loads(line)
|
||||
raw_function_calls.append(parsed_call)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
tool_calls = []
|
||||
for function_call in raw_function_calls:
|
||||
if "name" in function_call and "arguments" in function_call:
|
||||
tool_calls.append(
|
||||
ToolCall(type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
arguments=json.dumps(
|
||||
function_call["arguments"],
|
||||
ensure_ascii=False))))
|
||||
|
||||
processed_pos = processed_output.find(self.tool_call_start_token)
|
||||
if processed_pos != -1:
|
||||
processed_content = processed_output[:processed_pos].strip()
|
||||
|
||||
if processed_content:
|
||||
lines = processed_content.split('\n')
|
||||
for line in reversed(lines):
|
||||
line = line.strip()
|
||||
if line:
|
||||
pos = model_output.find(line)
|
||||
if pos != -1:
|
||||
content = model_output[:pos + len(line)]
|
||||
break
|
||||
else:
|
||||
content = ""
|
||||
else:
|
||||
content = ""
|
||||
else:
|
||||
content = model_output
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=content.strip() if content.strip() else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"An unexpected error occurred during tool call extraction.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def _update_thinking_state(self, text: str) -> None:
|
||||
"""
|
||||
Update the thinking tag state based on text content.
|
||||
|
||||
Args:
|
||||
text: Text to analyze for thinking tags
|
||||
"""
|
||||
open_count = text.count("<think>")
|
||||
close_count = text.count("</think>")
|
||||
self.in_thinking_tag = open_count > close_count or (
|
||||
open_count == close_count and text.endswith("</think>"))
|
||||
|
||||
def _is_potential_tag_start(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text might be the start of a tool call tag.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text could be the start of a tool call tag
|
||||
"""
|
||||
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
|
||||
if any(
|
||||
tag.startswith(text[-i:])
|
||||
for i in range(1, min(len(text) + 1, len(tag)))):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_buffer_content(self, delta_text: str) -> bool:
|
||||
"""
|
||||
Determine if content should be buffered for later processing.
|
||||
|
||||
Args:
|
||||
delta_text: Delta text to check
|
||||
|
||||
Returns:
|
||||
True if content should be buffered
|
||||
"""
|
||||
if self.in_thinking_tag:
|
||||
return False
|
||||
return bool(self.pending_buffer
|
||||
or self.tool_call_start_token in delta_text
|
||||
or self.tool_call_end_token in delta_text
|
||||
or delta_text.startswith('<'))
|
||||
|
||||
def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
|
||||
"""
|
||||
Split delta text into safe content and potential tag content.
|
||||
|
||||
Args:
|
||||
delta_text: Delta text to split
|
||||
|
||||
Returns:
|
||||
Tuple of (safe_content, potential_tag_content)
|
||||
"""
|
||||
if self.in_thinking_tag:
|
||||
return delta_text, ""
|
||||
|
||||
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
|
||||
for i in range(1, len(tag)):
|
||||
tag_prefix = tag[:i]
|
||||
pos = delta_text.rfind(tag_prefix)
|
||||
if pos != -1 and tag.startswith(delta_text[pos:]):
|
||||
return delta_text[:pos], delta_text[pos:]
|
||||
return delta_text, ""
|
||||
|
||||
def _process_buffer(self, new_content: str) -> str:
|
||||
"""
|
||||
Process buffered content and return output content.
|
||||
|
||||
Args:
|
||||
new_content: New content to add to buffer
|
||||
|
||||
Returns:
|
||||
Processed output content
|
||||
"""
|
||||
self.pending_buffer += new_content
|
||||
output_content = ""
|
||||
|
||||
if self.in_thinking_tag:
|
||||
output_content = self.pending_buffer
|
||||
self.pending_buffer = ""
|
||||
return output_content
|
||||
|
||||
while self.pending_buffer:
|
||||
start_pos = self.pending_buffer.find(self.tool_call_start_token)
|
||||
end_pos = self.pending_buffer.find(self.tool_call_end_token)
|
||||
|
||||
if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
|
||||
tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
|
||||
elif end_pos != -1:
|
||||
tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
|
||||
else:
|
||||
if self._is_potential_tag_start(self.pending_buffer):
|
||||
break
|
||||
output_content += self.pending_buffer
|
||||
self.pending_buffer = ""
|
||||
break
|
||||
|
||||
output_content += self.pending_buffer[:tag_pos]
|
||||
self.pending_buffer = self.pending_buffer[tag_pos + tag_len:]
|
||||
|
||||
return output_content
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
"""Reset the streaming state to initial values."""
|
||||
self.streaming_state = {
|
||||
"current_tool_index": -1,
|
||||
"tool_ids": [],
|
||||
"sent_tools": [],
|
||||
}
|
||||
|
||||
def _advance_to_next_tool(self) -> None:
|
||||
"""Advance to the next tool in the streaming sequence."""
|
||||
self.streaming_state["current_tool_index"] = int(
|
||||
self.streaming_state["current_tool_index"]) + 1
|
||||
|
||||
def _set_current_tool_index(self, index: int) -> None:
|
||||
"""
|
||||
Set the current tool index.
|
||||
|
||||
Args:
|
||||
index: Tool index to set
|
||||
"""
|
||||
self.streaming_state["current_tool_index"] = index
|
||||
|
||||
def _get_current_tool_index(self) -> int:
|
||||
"""
|
||||
Get the current tool index.
|
||||
|
||||
Returns:
|
||||
Current tool index
|
||||
"""
|
||||
return int(self.streaming_state["current_tool_index"])
|
||||
|
||||
def _get_next_unsent_tool_index(self, tool_count: int) -> int:
|
||||
"""
|
||||
Get the index of the next unsent tool.
|
||||
|
||||
Args:
|
||||
tool_count: Total number of tools
|
||||
|
||||
Returns:
|
||||
Index of next unsent tool, or -1 if all tools sent
|
||||
"""
|
||||
sent_tools = list(self.streaming_state["sent_tools"])
|
||||
for i in range(tool_count):
|
||||
if i < len(sent_tools):
|
||||
if not sent_tools[i]["sent_name"]:
|
||||
return i
|
||||
else:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def _ensure_state_arrays(self, tool_count: int) -> None:
|
||||
"""
|
||||
Ensure state arrays have sufficient capacity for tool_count tools.
|
||||
|
||||
Args:
|
||||
tool_count: Number of tools to prepare for
|
||||
"""
|
||||
sent_tools = list(self.streaming_state["sent_tools"])
|
||||
tool_ids = list(self.streaming_state["tool_ids"])
|
||||
|
||||
while len(sent_tools) < tool_count:
|
||||
sent_tools.append({
|
||||
"sent_name": False,
|
||||
"sent_arguments": "",
|
||||
"id": make_tool_call_id(),
|
||||
})
|
||||
|
||||
while len(tool_ids) < tool_count:
|
||||
tool_ids.append(None)
|
||||
|
||||
self.streaming_state["sent_tools"] = sent_tools
|
||||
self.streaming_state["tool_ids"] = tool_ids
|
||||
|
||||
def _detect_tools_in_text(self, text: str) -> int:
|
||||
"""
|
||||
Detect the number of tools in text by counting name patterns.
|
||||
|
||||
Args:
|
||||
text: Text to analyze
|
||||
|
||||
Returns:
|
||||
Number of tools detected
|
||||
"""
|
||||
matches = self.tool_name_pattern.findall(text)
|
||||
return len(matches)
|
||||
|
||||
def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Find the boundaries of tool calls in text.
|
||||
|
||||
Args:
|
||||
text: Text to analyze
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions for tool calls
|
||||
"""
|
||||
boundaries = []
|
||||
i = 0
|
||||
while i < len(text):
|
||||
if text[i] == '{':
|
||||
start = i
|
||||
depth = 0
|
||||
has_name = False
|
||||
has_arguments = False
|
||||
|
||||
while i < len(text):
|
||||
if text[i] == '{':
|
||||
depth += 1
|
||||
elif text[i] == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end = i + 1
|
||||
segment = text[start:end]
|
||||
if '"name"' in segment and '"arguments"' in segment:
|
||||
boundaries.append((start, end))
|
||||
break
|
||||
|
||||
if not has_name and '"name"' in text[start:i + 1]:
|
||||
has_name = True
|
||||
if not has_arguments and '"arguments"' in text[start:i +
|
||||
1]:
|
||||
has_arguments = True
|
||||
|
||||
i += 1
|
||||
|
||||
if depth > 0 and has_name:
|
||||
boundaries.append((start, i))
|
||||
else:
|
||||
i += 1
|
||||
return boundaries
|
||||
|
||||
def _extract_tool_args(self, tool_content: str,
|
||||
args_match: re.Match[str]) -> str:
|
||||
"""
|
||||
Extract tool arguments from tool content.
|
||||
|
||||
Args:
|
||||
tool_content: Tool call content
|
||||
args_match: Regex match for arguments pattern
|
||||
|
||||
Returns:
|
||||
Extracted arguments as string
|
||||
"""
|
||||
args_start_pos = args_match.end()
|
||||
remaining_content = tool_content[args_start_pos:]
|
||||
|
||||
if remaining_content.strip().startswith('{'):
|
||||
depth = 0
|
||||
for i, char in enumerate(remaining_content):
|
||||
if char == '{':
|
||||
depth += 1
|
||||
elif char == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return remaining_content[:i + 1]
|
||||
else:
|
||||
args_end = remaining_content.find('}')
|
||||
if args_end > 0:
|
||||
return remaining_content[:args_end].strip()
|
||||
|
||||
return remaining_content.rstrip('}').strip()
|
||||
|
||||
def _get_current_tool_content(
|
||||
self, text: str,
|
||||
tool_index: int) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Get the content of a specific tool by index.
|
||||
|
||||
Args:
|
||||
text: Text containing tool calls
|
||||
tool_index: Index of tool to extract
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_name, tool_arguments) or (None, None) if not found
|
||||
"""
|
||||
boundaries = self._find_tool_boundaries(text)
|
||||
|
||||
if tool_index >= len(boundaries):
|
||||
return None, None
|
||||
|
||||
start, end = boundaries[tool_index]
|
||||
tool_content = text[start:end]
|
||||
|
||||
name_match = self.tool_name_pattern.search(tool_content)
|
||||
name = name_match.group(1) if name_match else None
|
||||
|
||||
args_match = self.tool_args_pattern.search(tool_content)
|
||||
if args_match:
|
||||
try:
|
||||
args_text = self._extract_tool_args(tool_content, args_match)
|
||||
return name, args_text
|
||||
except Exception:
|
||||
remaining_content = tool_content[args_match.end():]
|
||||
args_text = remaining_content.rstrip('}').strip()
|
||||
return name, args_text
|
||||
|
||||
return name, None
|
||||
|
||||
def _handle_tool_name_streaming(
|
||||
self, tool_content: str,
|
||||
tool_count: int) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Handle streaming of tool names.
|
||||
|
||||
Args:
|
||||
tool_content: Content containing tool calls
|
||||
tool_count: Total number of tools
|
||||
|
||||
Returns:
|
||||
DeltaMessage with tool name or None if no tool to stream
|
||||
"""
|
||||
next_idx = self._get_next_unsent_tool_index(tool_count)
|
||||
|
||||
if next_idx == -1:
|
||||
return None
|
||||
|
||||
boundaries = self._find_tool_boundaries(tool_content)
|
||||
if next_idx >= len(boundaries):
|
||||
return None
|
||||
|
||||
tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
|
||||
if not tool_name:
|
||||
return None
|
||||
|
||||
self._set_current_tool_index(next_idx)
|
||||
sent_tools = list(self.streaming_state["sent_tools"])
|
||||
tool_ids = list(self.streaming_state["tool_ids"])
|
||||
|
||||
tool_id = sent_tools[next_idx]["id"]
|
||||
tool_ids[next_idx] = tool_id
|
||||
sent_tools[next_idx]["sent_name"] = True
|
||||
|
||||
self.streaming_state["sent_tools"] = sent_tools
|
||||
self.streaming_state["tool_ids"] = tool_ids
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=next_idx,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name).model_dump(exclude_none=True))
|
||||
])
|
||||
|
||||
def _handle_tool_args_streaming(
|
||||
self, tool_content: str,
|
||||
tool_count: int) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Handle streaming of tool arguments.
|
||||
|
||||
Args:
|
||||
tool_content: Content containing tool calls
|
||||
tool_count: Total number of tools
|
||||
|
||||
Returns:
|
||||
DeltaMessage with tool arguments or None if no arguments to stream
|
||||
"""
|
||||
current_idx = self._get_current_tool_index()
|
||||
|
||||
if current_idx < 0 or current_idx >= tool_count:
|
||||
return None
|
||||
|
||||
tool_name, tool_args = self._get_current_tool_content(
|
||||
tool_content, current_idx)
|
||||
if not tool_name or tool_args is None:
|
||||
return None
|
||||
|
||||
sent_tools = list(self.streaming_state["sent_tools"])
|
||||
|
||||
if not sent_tools[current_idx]["sent_name"]:
|
||||
return None
|
||||
|
||||
clean_args = self._clean_duplicate_braces(tool_args)
|
||||
sent_args = sent_tools[current_idx]["sent_arguments"]
|
||||
|
||||
if clean_args != sent_args:
|
||||
if sent_args and clean_args.startswith(sent_args):
|
||||
args_delta = extract_intermediate_diff(clean_args, sent_args)
|
||||
if args_delta:
|
||||
args_delta = self._clean_delta_braces(args_delta)
|
||||
sent_tools[current_idx]["sent_arguments"] = clean_args
|
||||
self.streaming_state["sent_tools"] = sent_tools
|
||||
|
||||
if clean_args.endswith('}'):
|
||||
self._advance_to_next_tool()
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=args_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
elif not sent_args and clean_args:
|
||||
clean_args_delta = self._clean_delta_braces(clean_args)
|
||||
sent_tools[current_idx]["sent_arguments"] = clean_args
|
||||
self.streaming_state["sent_tools"] = sent_tools
|
||||
|
||||
if clean_args.endswith('}'):
|
||||
self._advance_to_next_tool()
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=clean_args_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
|
||||
return None
|
||||
|
||||
def _is_end_tool_calls(self, current_text: str) -> bool:
|
||||
if self.tool_call_end_token not in current_text:
|
||||
return False
|
||||
|
||||
end_token_positions = []
|
||||
search_start = 0
|
||||
while True:
|
||||
pos = current_text.find(self.tool_call_end_token, search_start)
|
||||
if pos == -1:
|
||||
break
|
||||
end_token_positions.append(pos)
|
||||
search_start = pos + 1
|
||||
|
||||
think_regions = []
|
||||
for match in re.finditer(self.thinking_tag_pattern,
|
||||
current_text,
|
||||
flags=re.DOTALL):
|
||||
think_regions.append((match.start(), match.end()))
|
||||
|
||||
for pos in end_token_positions:
|
||||
in_think = any(pos >= t_start and pos < t_end
|
||||
for t_start, t_end in think_regions)
|
||||
if not in_think:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
self._update_thinking_state(current_text)
|
||||
|
||||
if self.in_thinking_tag:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self._should_buffer_content(delta_text):
|
||||
buffered_output = self._process_buffer(delta_text)
|
||||
return DeltaMessage(
|
||||
content=buffered_output) if buffered_output else None
|
||||
|
||||
if self._is_end_tool_calls(current_text):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
safe_content, potential_tag = self._split_content_for_buffering(
|
||||
delta_text)
|
||||
if potential_tag:
|
||||
self.pending_buffer += potential_tag
|
||||
return DeltaMessage(content=safe_content) if safe_content else None
|
||||
|
||||
processed_current_text = self.preprocess_model_output(current_text)
|
||||
|
||||
if self.tool_call_start_token not in processed_current_text:
|
||||
if (self.tool_call_end_token in delta_text
|
||||
and self.tool_call_start_token in current_text):
|
||||
return None
|
||||
if delta_text.strip(
|
||||
) == '' and self.tool_call_start_token in current_text:
|
||||
return None
|
||||
if (self._get_current_tool_index() != -1
|
||||
and self.tool_call_end_token in current_text):
|
||||
self._reset_streaming_state()
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if (self.tool_call_start_token_id is not None
|
||||
and self.tool_call_start_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
return None
|
||||
|
||||
original_tool_start = self._find_tool_start_outside_thinking(
|
||||
current_text)
|
||||
if original_tool_start is None:
|
||||
return None
|
||||
|
||||
content_before_tools = self._extract_content_before_tools(
|
||||
current_text, delta_text, original_tool_start)
|
||||
if content_before_tools:
|
||||
return DeltaMessage(content=content_before_tools)
|
||||
|
||||
try:
|
||||
tool_content = self._extract_tool_content(current_text,
|
||||
original_tool_start)
|
||||
current_tools_count = self._detect_tools_in_text(tool_content)
|
||||
|
||||
if current_tools_count == 0:
|
||||
return None
|
||||
|
||||
if self._get_current_tool_index() == -1:
|
||||
self._reset_streaming_state()
|
||||
|
||||
self._ensure_state_arrays(current_tools_count)
|
||||
|
||||
return (self._handle_tool_name_streaming(tool_content,
|
||||
current_tools_count)
|
||||
or self._handle_tool_args_streaming(
|
||||
tool_content, current_tools_count))
|
||||
|
||||
except Exception:
|
||||
logger.exception("An unexpected error occurred ",
|
||||
"during streaming tool call handling.")
|
||||
return None
|
||||
|
||||
def _find_tool_start_outside_thinking(self,
|
||||
current_text: str) -> Optional[int]:
|
||||
"""
|
||||
Find the start position of tool calls outside of thinking tags.
|
||||
|
||||
Args:
|
||||
current_text: Current text to search
|
||||
|
||||
Returns:
|
||||
Position of tool call start or None if not found
|
||||
"""
|
||||
search_start = 0
|
||||
while True:
|
||||
pos = current_text.find(self.tool_call_start_token, search_start)
|
||||
if pos == -1:
|
||||
return None
|
||||
|
||||
think_regions = [(m.start(), m.end()) for m in re.finditer(
|
||||
r"<think>(.*?)</think>", current_text, flags=re.DOTALL)]
|
||||
in_think = any(pos >= t_start and pos < t_end
|
||||
for t_start, t_end in think_regions)
|
||||
|
||||
if not in_think:
|
||||
return pos
|
||||
|
||||
search_start = pos + 1
|
||||
|
||||
def _extract_content_before_tools(self, current_text: str, delta_text: str,
|
||||
tool_start: int) -> Optional[str]:
|
||||
"""
|
||||
Extract content that appears before tool calls.
|
||||
|
||||
Args:
|
||||
current_text: Current text
|
||||
delta_text: Delta text
|
||||
tool_start: Start position of tools
|
||||
|
||||
Returns:
|
||||
Content before tools or None
|
||||
"""
|
||||
if tool_start > 0:
|
||||
delta_start_pos = len(current_text) - len(delta_text)
|
||||
if delta_start_pos < tool_start:
|
||||
content_part = delta_text
|
||||
if delta_start_pos + len(delta_text) > tool_start:
|
||||
content_part = delta_text[:tool_start - delta_start_pos]
|
||||
return content_part if content_part else None
|
||||
return None
|
||||
|
||||
def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
|
||||
"""
|
||||
Extract tool content from current text starting at tool_start.
|
||||
|
||||
Args:
|
||||
current_text: Current text
|
||||
tool_start: Start position of tool calls
|
||||
|
||||
Returns:
|
||||
Extracted tool content
|
||||
"""
|
||||
tool_content_start = tool_start + len(self.tool_call_start_token)
|
||||
tool_content = current_text[tool_content_start:]
|
||||
|
||||
end_pos = tool_content.find(self.tool_call_end_token)
|
||||
if end_pos != -1:
|
||||
tool_content = tool_content[:end_pos]
|
||||
|
||||
return tool_content
|
||||
369
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
369
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
@@ -0,0 +1,369 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
|
||||
id: str = Field(
|
||||
default_factory=lambda: MistralToolCall.generate_random_id())
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
@staticmethod
|
||||
def is_valid_id(id: str) -> bool:
|
||||
return id.isalnum() and len(id) == 9
|
||||
|
||||
|
||||
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
|
||||
return isinstance(model_tokenizer, MistralTokenizer) \
|
||||
and model_tokenizer.version >= 11
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
|
||||
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
|
||||
- the examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||
self.fn_name_regex = re.compile(
|
||||
r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
|
||||
else:
|
||||
self.fn_name_regex = None
|
||||
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
"the tokenizer!")
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if not isinstance(
|
||||
self.model_tokenizer, MistralTokenizer
|
||||
) and request.tools and request.tool_choice != 'none':
|
||||
# Do not skip special tokens when using chat template
|
||||
# with Mistral parser as TOOL_CALL token is needed
|
||||
# for tool detection.
|
||||
# Note: we don't want skip_special_tokens=False
|
||||
# with MistralTokenizer as it is incompatible
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
make sure your tool call arguments don't ever include quotes!
|
||||
"""
|
||||
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if self.bot_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# first remove the BOT token
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
|
||||
try:
|
||||
# we first try to directly load the json as parsing very nested
|
||||
# jsons is difficult
|
||||
try:
|
||||
if self.fn_name_regex:
|
||||
matches = self.fn_name_regex.findall(tool_content)
|
||||
|
||||
function_call_arr = []
|
||||
for match in matches:
|
||||
fn_name = match[0]
|
||||
args = match[1]
|
||||
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append({
|
||||
"name": fn_name,
|
||||
"arguments": json.loads(args)
|
||||
})
|
||||
else:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
# use a regex to find the part corresponding to the tool call.
|
||||
# NOTE: This use case should not happen if the model is trained
|
||||
# correctly. It's an easy possible fix so it's included, but
|
||||
# can be brittle for very complex / highly nested tool calls
|
||||
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
|
||||
# Tool Call
|
||||
tool_calls: list[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"],
|
||||
ensure_ascii=False)))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
content = model_output.split(self.bot_token)[0]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if len(content) > 0 else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=tool_content)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.bot_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the BOT token which means the start of tool
|
||||
# calling
|
||||
if (self.bot_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion any don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# replace BOT token with empty string, and convert single quotes
|
||||
# to double to allow parsing as JSON since mistral uses single
|
||||
# quotes instead of double for tool calls
|
||||
parsable_arr = current_text.split(self.bot_token)[-1]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: list[dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff, ensure_ascii=False).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=MistralToolCall.generate_random_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
if ('"}' in new_text):
|
||||
new_text = new_text[:new_text.rindex('"}')]
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)[:-2]
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
if (new_text not in cur_arguments_json):
|
||||
return None
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
rindex(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
prev_args_json = json.dumps(prev_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
93
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
93
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.entrypoints.harmony_utils import parse_output_into_messages
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("openai")
|
||||
class OpenAIToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
token_ids: Sequence[int] | None = None,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if token_ids is None:
|
||||
raise NotImplementedError(
|
||||
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
|
||||
)
|
||||
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
tool_calls = []
|
||||
final_content = None
|
||||
|
||||
if len(parser.messages) > 0:
|
||||
for msg in parser.messages:
|
||||
if len(msg.content) < 1:
|
||||
continue
|
||||
msg_text = msg.content[0].text
|
||||
if msg.recipient and msg.recipient.startswith("functions."):
|
||||
# If no content-type is given assume JSON, as that's the
|
||||
# most common case with gpt-oss models.
|
||||
if not msg.content_type or "json" in msg.content_type:
|
||||
# load and dump the JSON text to check validity and
|
||||
# remove any extra newlines or other odd formatting
|
||||
try:
|
||||
tool_args = json.dumps(json.loads(msg_text))
|
||||
except json.JSONDecodeError:
|
||||
logger.exception(
|
||||
"Error decoding JSON tool call from response.")
|
||||
tool_args = msg_text
|
||||
else:
|
||||
tool_args = msg_text
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=msg.recipient.split("functions.")[1],
|
||||
arguments=tool_args,
|
||||
),
|
||||
))
|
||||
elif msg.channel == "final":
|
||||
final_content = msg_text
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=final_content,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
raise NotImplementedError(
|
||||
"Not being used, manual parsing in serving_chat.py" # noqa: E501
|
||||
)
|
||||
112
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
Normal file
112
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("phi4_mini_json")
|
||||
class Phi4MiniJsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for phi-4-mini models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token: str = "functools"
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
logger.debug("Model output: %s", model_output)
|
||||
|
||||
pattern = r'functools\[(.*?)\]'
|
||||
matches = re.search(pattern, model_output, re.DOTALL)
|
||||
|
||||
if not matches:
|
||||
logger.debug("No function calls found")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
function_call_arr: list[dict[str, Any]] = []
|
||||
try:
|
||||
json_content = '[' + matches.group(1) + ']'
|
||||
|
||||
function_call_arr = json.loads(json_content)
|
||||
logger.debug("Successfully extracted %d function calls",
|
||||
len(function_call_arr))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to parse function calls from model output. "
|
||||
"Error: %s", str(e))
|
||||
|
||||
tool_calls: list[ToolCall] = [
|
||||
ToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
raw_function_call["arguments"]
|
||||
if "arguments" in raw_function_call else
|
||||
raw_function_call["parameters"],
|
||||
ensure_ascii=False),
|
||||
)) for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
ret = ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None)
|
||||
return ret
|
||||
|
||||
except Exception:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Optional[DeltaMessage]:
|
||||
|
||||
return None
|
||||
308
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
Normal file
308
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("pythonic")
|
||||
class PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for models that produce tool calls in a pythonic style,
|
||||
such as Llama 3.2 and Llama 4 models.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
|
||||
"""
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not current_text.startswith("["):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(
|
||||
tool_calls) - 1 or ")]" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = (added_text[:-2]
|
||||
if not new_call_complete else "")
|
||||
if not new_call_complete and added_text[-2] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
|
||||
new_call, index, withheld_suffix)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (delta.function is not None
|
||||
and delta.function.arguments is not None):
|
||||
self.streamed_args_for_tool[
|
||||
index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining its final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content='')
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError(
|
||||
"Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(arguments,
|
||||
ensure_ascii=False)),
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[:text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[:text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
|
||||
"[") and not text.endswith(")"):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
index: int,
|
||||
withheld_suffix: str) -> Union[DeltaToolCall, None]:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[:-len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
))
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args):]
|
||||
return DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(
|
||||
arguments=arg_diff)) if arg_diff else None
|
||||
707
vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
Normal file
707
vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
Normal file
@@ -0,0 +1,707 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("qwen3_coder")
|
||||
class Qwen3CoderToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
# Override base class type - we use string IDs for tool calls
|
||||
self.current_tool_id: Optional[str] = None # type: ignore
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
# Sentinel tokens for streaming mode
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_prefix: str = "<function="
|
||||
self.function_end_token: str = "</function>"
|
||||
self.parameter_prefix: str = "<parameter="
|
||||
self.parameter_end_token: str = "</parameter>"
|
||||
self.is_tool_call_started: bool = False
|
||||
self.failed_count: int = 0
|
||||
|
||||
# Enhanced streaming state - reset for each new message
|
||||
self._reset_streaming_state()
|
||||
|
||||
# Regex patterns
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
|
||||
self.tool_call_function_regex = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||
self.tool_call_parameter_regex = re.compile(
|
||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||
re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Qwen3 XML Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
logger.info("vLLM Successfully import tool parser %s !",
|
||||
self.__class__.__name__)
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
# Store accumulated parameters for type conversion
|
||||
self.accumulated_params = {}
|
||||
self.streaming_request = None
|
||||
|
||||
def _get_arguments_config(
|
||||
self, func_name: str,
|
||||
tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
|
||||
"""Extract argument configuration for a function."""
|
||||
if tools is None:
|
||||
return {}
|
||||
for config in tools:
|
||||
if not hasattr(config, "type") or not (hasattr(
|
||||
config, "function") and hasattr(config.function, "name")):
|
||||
continue
|
||||
if config.type == "function" and config.function.name == func_name:
|
||||
if not hasattr(config.function, "parameters"):
|
||||
return {}
|
||||
params = config.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
return params["properties"]
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
return {}
|
||||
logger.warning("Tool '%s' is not defined in the tools list.",
|
||||
func_name)
|
||||
return {}
|
||||
|
||||
def _convert_param_value(self, param_value: str, param_name: str,
|
||||
param_config: dict, func_name: str) -> Any:
|
||||
"""Convert parameter value based on its type in the schema."""
|
||||
# Handle null value for any type
|
||||
if param_value.lower() == "null":
|
||||
return None
|
||||
|
||||
if param_name not in param_config:
|
||||
if param_config != {}:
|
||||
logger.warning(
|
||||
"Parsed parameter '%s' is not defined in the tool "
|
||||
"parameters for tool '%s', directly returning the "
|
||||
"string value.", param_name, func_name)
|
||||
return param_value
|
||||
|
||||
if isinstance(param_config[param_name],
|
||||
dict) and "type" in param_config[param_name]:
|
||||
param_type = str(param_config[param_name]["type"]).strip().lower()
|
||||
else:
|
||||
param_type = "string"
|
||||
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
|
||||
return param_value
|
||||
elif param_type.startswith("int") or param_type.startswith(
|
||||
"uint") or param_type.startswith(
|
||||
"long") or param_type.startswith(
|
||||
"short") or param_type.startswith("unsigned"):
|
||||
try:
|
||||
return int(param_value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not an "
|
||||
"integer in tool '%s', degenerating to string.",
|
||||
param_value, param_name, func_name)
|
||||
return param_value
|
||||
elif param_type.startswith("num") or param_type.startswith("float"):
|
||||
try:
|
||||
float_param_value = float(param_value)
|
||||
return float_param_value if float_param_value - int(
|
||||
float_param_value) != 0 else int(float_param_value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a float "
|
||||
"in tool '%s', degenerating to string.", param_value,
|
||||
param_name, func_name)
|
||||
return param_value
|
||||
elif param_type in ["boolean", "bool", "binary"]:
|
||||
param_value = param_value.lower()
|
||||
if param_value not in ["true", "false"]:
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a boolean "
|
||||
"(`true` or `false`) in tool '%s', degenerating to "
|
||||
"false.", param_value, param_name, func_name)
|
||||
return param_value == "true"
|
||||
else:
|
||||
if param_type in ["object", "array", "arr"
|
||||
] or param_type.startswith(
|
||||
"dict") or param_type.startswith("list"):
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
return param_value
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' cannot be "
|
||||
"parsed with json.loads in tool '%s', will try "
|
||||
"other methods to parse it.", param_value, param_name,
|
||||
func_name)
|
||||
try:
|
||||
param_value = ast.literal_eval(param_value) # safer
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' cannot be "
|
||||
"converted via Python `ast.literal_eval()` in tool "
|
||||
"'%s', degenerating to string.", param_value, param_name,
|
||||
func_name)
|
||||
return param_value
|
||||
|
||||
def _parse_xml_function_call(
|
||||
self, function_call_str: str,
|
||||
tools: Optional[list[ChatCompletionToolsParam]]
|
||||
) -> Optional[ToolCall]:
|
||||
|
||||
# Extract function name
|
||||
end_index = function_call_str.index(">")
|
||||
function_name = function_call_str[:end_index]
|
||||
param_config = self._get_arguments_config(function_name, tools)
|
||||
parameters = function_call_str[end_index + 1:]
|
||||
param_dict = {}
|
||||
for match_text in self.tool_call_parameter_regex.findall(parameters):
|
||||
idx = match_text.index(">")
|
||||
param_name = match_text[:idx]
|
||||
param_value = str(match_text[idx + 1:])
|
||||
# Remove prefix and trailing \n
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = self._convert_param_value(
|
||||
param_value, param_name, param_config, function_name)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(param_dict,
|
||||
ensure_ascii=False)),
|
||||
)
|
||||
|
||||
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||
# Find all tool calls
|
||||
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||
raw_tool_calls = [
|
||||
match[0] if match[0] else match[1] for match in matched_ranges
|
||||
]
|
||||
|
||||
# Back-off strategy if no tool_call tags found
|
||||
if len(raw_tool_calls) == 0:
|
||||
raw_tool_calls = [model_output]
|
||||
|
||||
raw_function_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
raw_function_calls.extend(
|
||||
self.tool_call_function_regex.findall(tool_call))
|
||||
|
||||
function_calls = [
|
||||
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||
]
|
||||
return function_calls
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# Quick check to avoid unnecessary processing
|
||||
if self.tool_call_prefix not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
function_calls = self._get_function_calls(model_output)
|
||||
if len(function_calls) == 0:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
tool_calls = [
|
||||
self._parse_xml_function_call(function_call_str, request.tools)
|
||||
for function_call_str in function_calls
|
||||
]
|
||||
|
||||
# Populate prev_tool_call_arr for serving layer to set finish_reason
|
||||
self.prev_tool_call_arr.clear() # Clear previous calls
|
||||
for tool_call in tool_calls:
|
||||
if tool_call:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name":
|
||||
tool_call.function.name,
|
||||
"arguments":
|
||||
tool_call.function.arguments,
|
||||
})
|
||||
|
||||
# Extract content before tool calls
|
||||
content_index = model_output.find(self.tool_call_start_token)
|
||||
idx = model_output.find(self.tool_call_prefix)
|
||||
content_index = content_index if content_index >= 0 else idx
|
||||
content = model_output[:content_index] # .rstrip()
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=(len(tool_calls) > 0),
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# Store request for type conversion
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
self.streaming_request = request
|
||||
|
||||
# If no delta text, return None unless it's an EOS token after tools
|
||||
if not delta_text:
|
||||
# Check if this is an EOS token after all tool calls are complete
|
||||
# Check for tool calls in text even if is_tool_call_started
|
||||
# is False (might have been reset after processing all tools)
|
||||
if (delta_token_ids
|
||||
and self.tool_call_end_token_id not in delta_token_ids):
|
||||
# Count complete tool calls
|
||||
complete_calls = len(
|
||||
self.tool_call_complete_regex.findall(current_text))
|
||||
|
||||
# If we have completed tool calls and populated
|
||||
# prev_tool_call_arr
|
||||
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
|
||||
# Check if all tool calls are closed
|
||||
open_calls = current_text.count(
|
||||
self.tool_call_start_token) - current_text.count(
|
||||
self.tool_call_end_token)
|
||||
if open_calls == 0:
|
||||
# Return empty delta for finish_reason processing
|
||||
return DeltaMessage(content="")
|
||||
elif not self.is_tool_call_started and current_text:
|
||||
# This is a regular content response that's now complete
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
# Update accumulated text
|
||||
self.accumulated_text = current_text
|
||||
|
||||
# Check if we need to advance to next tool
|
||||
if self.json_closed and not self.in_function:
|
||||
# Check if this tool call has ended
|
||||
tool_ends = current_text.count(self.tool_call_end_token)
|
||||
if tool_ends > self.current_tool_index:
|
||||
# This tool has ended, advance to next
|
||||
self.current_tool_index += 1
|
||||
self.header_sent = False
|
||||
self.param_count = 0
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
self.accumulated_params = {}
|
||||
|
||||
# Check if there are more tool calls
|
||||
tool_starts = current_text.count(self.tool_call_start_token)
|
||||
if self.current_tool_index >= tool_starts:
|
||||
# No more tool calls
|
||||
self.is_tool_call_started = False
|
||||
# Continue processing next tool
|
||||
return None
|
||||
|
||||
# Handle normal content before tool calls
|
||||
if not self.is_tool_call_started:
|
||||
# Check if tool call is starting
|
||||
if (self.tool_call_start_token_id in delta_token_ids
|
||||
or self.tool_call_start_token in delta_text):
|
||||
self.is_tool_call_started = True
|
||||
# Return any content before the tool call
|
||||
if self.tool_call_start_token in delta_text:
|
||||
content_before = delta_text[:delta_text.index(
|
||||
self.tool_call_start_token)]
|
||||
if content_before:
|
||||
return DeltaMessage(content=content_before)
|
||||
return None
|
||||
else:
|
||||
# Check if we're between tool calls - skip whitespace
|
||||
if (current_text.rstrip().endswith(self.tool_call_end_token)
|
||||
and delta_text.strip() == ""):
|
||||
# We just ended a tool call, skip whitespace
|
||||
return None
|
||||
# Normal content, no tool call
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Check if we're between tool calls (waiting for next one)
|
||||
# Count tool calls we've seen vs processed
|
||||
tool_starts_count = current_text.count(self.tool_call_start_token)
|
||||
if self.current_tool_index >= tool_starts_count:
|
||||
# We're past all tool calls, shouldn't be here
|
||||
return None
|
||||
|
||||
# We're in a tool call, find the current tool call portion
|
||||
# Need to find the correct tool call based on current_tool_index
|
||||
tool_start_positions: list[int] = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = current_text.find(self.tool_call_start_token, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
tool_start_positions.append(idx)
|
||||
idx += len(self.tool_call_start_token)
|
||||
|
||||
if self.current_tool_index >= len(tool_start_positions):
|
||||
# No more tool calls to process yet
|
||||
return None
|
||||
|
||||
tool_start_idx = tool_start_positions[self.current_tool_index]
|
||||
# Find where this tool call ends (or current position if not ended yet)
|
||||
tool_end_idx = current_text.find(self.tool_call_end_token,
|
||||
tool_start_idx)
|
||||
if tool_end_idx == -1:
|
||||
tool_text = current_text[tool_start_idx:]
|
||||
else:
|
||||
tool_text = current_text[tool_start_idx:tool_end_idx +
|
||||
len(self.tool_call_end_token)]
|
||||
|
||||
# Looking for function header
|
||||
if not self.header_sent:
|
||||
if self.tool_call_prefix in tool_text:
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_end = tool_text.find(">", func_start)
|
||||
|
||||
if func_end != -1:
|
||||
# Found complete function name
|
||||
self.current_function_name = tool_text[func_start:func_end]
|
||||
self.current_tool_id = self._generate_tool_call_id()
|
||||
self.header_sent = True
|
||||
self.in_function = True
|
||||
|
||||
# IMPORTANT: Add to prev_tool_call_arr immediately when
|
||||
# we detect a tool call. This ensures
|
||||
# finish_reason="tool_calls" even if parsing isn't complete
|
||||
already_added = any(
|
||||
tool.get("name") == self.current_function_name
|
||||
for tool in self.prev_tool_call_arr)
|
||||
if not already_added:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name": self.current_function_name,
|
||||
"arguments":
|
||||
"{}", # Placeholder, will be updated later
|
||||
})
|
||||
|
||||
# Send header with function info
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
id=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=self.current_function_name, arguments=""),
|
||||
type="function",
|
||||
)
|
||||
])
|
||||
return None
|
||||
|
||||
# We've sent header, now handle function body
|
||||
if self.in_function:
|
||||
# Send opening brace if not sent yet
|
||||
if (not self.json_started
|
||||
and self.parameter_prefix not in delta_text):
|
||||
self.json_started = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="{"),
|
||||
)
|
||||
])
|
||||
|
||||
# Make sure json_started is set if we're processing parameters
|
||||
if not self.json_started:
|
||||
self.json_started = True
|
||||
|
||||
# Check for function end in accumulated text
|
||||
if not self.json_closed and self.function_end_token in tool_text:
|
||||
# Close JSON
|
||||
self.json_closed = True
|
||||
|
||||
# Extract complete tool call to update
|
||||
# prev_tool_call_arr with final arguments
|
||||
# Find the function content
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_content_end = tool_text.find(self.function_end_token,
|
||||
func_start)
|
||||
if func_content_end != -1:
|
||||
func_content = tool_text[func_start:func_content_end]
|
||||
# Parse to get the complete arguments
|
||||
try:
|
||||
parsed_tool = self._parse_xml_function_call(
|
||||
func_content, self.streaming_request.tools
|
||||
if self.streaming_request else None)
|
||||
if parsed_tool:
|
||||
# Update existing entry in
|
||||
# prev_tool_call_arr with complete args
|
||||
for i, tool in enumerate(self.prev_tool_call_arr):
|
||||
if tool.get(
|
||||
"name") == parsed_tool.function.name:
|
||||
args = parsed_tool.function.arguments
|
||||
self.prev_tool_call_arr[i][
|
||||
"arguments"] = args
|
||||
break
|
||||
except Exception:
|
||||
pass # Ignore parsing errors during streaming
|
||||
|
||||
result = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="}"),
|
||||
)
|
||||
])
|
||||
|
||||
# Reset state for next tool
|
||||
self.in_function = False
|
||||
self.json_closed = True
|
||||
self.accumulated_params = {}
|
||||
|
||||
return result
|
||||
|
||||
# Look for parameters
|
||||
# Find all parameter starts
|
||||
param_starts = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = tool_text.find(self.parameter_prefix, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
param_starts.append(idx)
|
||||
idx += len(self.parameter_prefix)
|
||||
|
||||
# Check if we should start a new parameter
|
||||
if (not self.in_param and self.param_count < len(param_starts)
|
||||
and len(param_starts) > self.param_count):
|
||||
# Process the next parameter
|
||||
param_idx = param_starts[self.param_count]
|
||||
param_start = param_idx + len(self.parameter_prefix)
|
||||
remaining = tool_text[param_start:]
|
||||
|
||||
if ">" in remaining:
|
||||
# We have the complete parameter name
|
||||
name_end = remaining.find(">")
|
||||
self.current_param_name = remaining[:name_end]
|
||||
|
||||
# Find the parameter value
|
||||
value_start = param_start + name_end + 1
|
||||
value_text = tool_text[value_start:]
|
||||
if value_text.startswith("\n"):
|
||||
value_text = value_text[1:]
|
||||
|
||||
# Find where this parameter ends
|
||||
param_end_idx = value_text.find(self.parameter_end_token)
|
||||
if param_end_idx == -1:
|
||||
# No closing tag, look for next parameter or
|
||||
# function end
|
||||
next_param_idx = value_text.find(self.parameter_prefix)
|
||||
func_end_idx = value_text.find(self.function_end_token)
|
||||
|
||||
if next_param_idx != -1 and (func_end_idx == -1
|
||||
or next_param_idx
|
||||
< func_end_idx):
|
||||
param_end_idx = next_param_idx
|
||||
elif func_end_idx != -1:
|
||||
param_end_idx = func_end_idx
|
||||
else:
|
||||
# Neither found, check if tool call is complete
|
||||
if self.tool_call_end_token in tool_text:
|
||||
# Tool call is complete, so parameter
|
||||
# must be complete too. Use all
|
||||
# remaining text before function end
|
||||
param_end_idx = len(value_text)
|
||||
else:
|
||||
# Still streaming, wait for more content
|
||||
return None
|
||||
|
||||
if param_end_idx != -1:
|
||||
# Complete parameter found
|
||||
param_value = value_text[:param_end_idx]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Store raw value for later processing
|
||||
self.accumulated_params[
|
||||
self.current_param_name] = param_value
|
||||
|
||||
# Get parameter configuration for type conversion
|
||||
param_config = self._get_arguments_config(
|
||||
self.current_function_name or "",
|
||||
self.streaming_request.tools
|
||||
if self.streaming_request else None)
|
||||
|
||||
# Convert param value to appropriate type
|
||||
converted_value = self._convert_param_value(
|
||||
param_value, self.current_param_name, param_config,
|
||||
self.current_function_name or "")
|
||||
|
||||
# Build JSON fragment based on the converted type
|
||||
# Use json.dumps to properly serialize the value
|
||||
serialized_value = json.dumps(converted_value,
|
||||
ensure_ascii=False)
|
||||
|
||||
if self.param_count == 0:
|
||||
json_fragment = (f'"{self.current_param_name}": '
|
||||
f'{serialized_value}')
|
||||
else:
|
||||
json_fragment = (f', "{self.current_param_name}": '
|
||||
f'{serialized_value}')
|
||||
|
||||
self.param_count += 1
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=json_fragment),
|
||||
)
|
||||
])
|
||||
|
||||
# Continue parameter value - Not used in the current implementation
|
||||
# since we process complete parameters above
|
||||
if self.in_param:
|
||||
if self.parameter_end_token in delta_text:
|
||||
# End of parameter
|
||||
end_idx = delta_text.find(self.parameter_end_token)
|
||||
value_chunk = delta_text[:end_idx]
|
||||
|
||||
# Skip past > if at start
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
# Store complete value
|
||||
full_value = self.current_param_value + value_chunk
|
||||
self.accumulated_params[
|
||||
self.current_param_name] = full_value
|
||||
|
||||
# Get parameter configuration for type conversion
|
||||
param_config = self._get_arguments_config(
|
||||
self.current_function_name or "",
|
||||
self.streaming_request.tools
|
||||
if self.streaming_request else None)
|
||||
|
||||
# Convert the parameter value to the appropriate type
|
||||
converted_value = self._convert_param_value(
|
||||
full_value, self.current_param_name or "",
|
||||
param_config, self.current_function_name or "")
|
||||
|
||||
# Serialize the converted value
|
||||
serialized_value = json.dumps(converted_value,
|
||||
ensure_ascii=False)
|
||||
|
||||
# Since we've been streaming the quoted version,
|
||||
# we need to close it properly
|
||||
# This is complex - for now just complete the value
|
||||
self.in_param = False
|
||||
self.current_param_value = ""
|
||||
|
||||
# Just close the current parameter string
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments='"'), # Close the string quote
|
||||
)
|
||||
])
|
||||
else:
|
||||
# Continue accumulating value
|
||||
value_chunk = delta_text
|
||||
|
||||
# Handle first chunk after param name
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
if value_chunk:
|
||||
# Stream the escaped delta
|
||||
prev_escaped = json.dumps(
|
||||
self.current_param_value, ensure_ascii=False
|
||||
)[1:-1] if self.current_param_value else ""
|
||||
self.current_param_value += value_chunk
|
||||
full_escaped = json.dumps(self.current_param_value,
|
||||
ensure_ascii=False)[1:-1]
|
||||
delta_escaped = full_escaped[len(prev_escaped):]
|
||||
|
||||
if delta_escaped:
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_escaped),
|
||||
)
|
||||
])
|
||||
|
||||
return None
|
||||
1137
vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
Normal file
1137
vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
Normal file
File diff suppressed because it is too large
Load Diff
679
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
679
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from qwen3coder xml parser, All rights reserved.
|
||||
# ruff: noqa: E501
|
||||
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("seed_oss")
|
||||
class SeedOssToolParser(ToolParser):
|
||||
TOOL_CALL_START = "<seed:tool_call>"
|
||||
TOOL_CALL_END = "</seed:tool_call>"
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# --- streaming state ---
|
||||
self._reset_streaming_state()
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
self.tool_call_start_token: str = self.TOOL_CALL_START
|
||||
self.tool_call_end_token: str = self.TOOL_CALL_END
|
||||
# Sentinel tokens for streaming mode
|
||||
self.tool_call_prefix: str = "<function="
|
||||
self.function_end_token: str = "</function>"
|
||||
self.parameter_prefix: str = "<parameter="
|
||||
self.parameter_end_token: str = "</parameter>"
|
||||
self.think_start_token: str = "<seed:think>"
|
||||
self.think_end_token: str = "</seed:think>"
|
||||
self.is_tool_call_started: bool = False
|
||||
self.is_thinking_end: bool = False
|
||||
self.failed_count: int = 0
|
||||
self._reset_streaming_state()
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Seed_Oss XML parser: tokenizer did not include "
|
||||
"<seed:tool_call> or its closing tag.")
|
||||
|
||||
tool_start_re = re.escape(self.tool_call_start_token)
|
||||
tool_end_re = re.escape(self.tool_call_end_token)
|
||||
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
|
||||
self.tool_call_regex = re.compile(
|
||||
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
|
||||
re.DOTALL)
|
||||
|
||||
self.tool_call_function_regex = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||
self.tool_call_parameter_regex = re.compile(
|
||||
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
|
||||
|
||||
logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
|
||||
self.__class__.__name__)
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = -1
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
|
||||
def _parse_xml_function_call(
|
||||
self, function_call_str: str,
|
||||
tools: Optional[list[ChatCompletionToolsParam]]
|
||||
) -> Optional[ToolCall]:
|
||||
|
||||
def get_arguments_config(func_name: str) -> dict:
|
||||
if tools is None:
|
||||
return {}
|
||||
for config in tools:
|
||||
if not hasattr(config, "type") or not (
|
||||
hasattr(config, "function")
|
||||
and hasattr(config.function, "name")):
|
||||
continue
|
||||
if (config.type == "function"
|
||||
and config.function.name == func_name):
|
||||
if not hasattr(config.function, "parameters"):
|
||||
return {}
|
||||
params = config.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
return params["properties"]
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
return {}
|
||||
logger.warning("Tool '%s' is not defined in the tools list.",
|
||||
func_name)
|
||||
return {}
|
||||
|
||||
def convert_param_value(param_value: str, param_name: str,
|
||||
param_config: dict, func_name: str) -> Any:
|
||||
# Handle null value for any type
|
||||
if param_value.lower() == "null":
|
||||
return None
|
||||
|
||||
if param_name not in param_config:
|
||||
if param_config != {}:
|
||||
logger.warning(
|
||||
"Parsed parameter '%s' is not defined in "
|
||||
"the tool parameters for tool '%s', "
|
||||
"directly returning the string value.", param_name,
|
||||
func_name)
|
||||
return param_value
|
||||
|
||||
if (isinstance(param_config[param_name], dict)
|
||||
and "type" in param_config[param_name]):
|
||||
param_type = str(
|
||||
param_config[param_name]["type"]).strip().lower()
|
||||
else:
|
||||
param_type = "string"
|
||||
if param_type in [
|
||||
"string", "str", "text", "varchar", "char", "enum"
|
||||
]:
|
||||
return param_value
|
||||
elif (param_type.startswith("int") or param_type.startswith("uint")
|
||||
or param_type.startswith("long")
|
||||
or param_type.startswith("short")
|
||||
or param_type.startswith("unsigned")):
|
||||
try:
|
||||
param_value = int(param_value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not an integer in tool "
|
||||
"'%s', degenerating to string.", param_value,
|
||||
param_name, func_name)
|
||||
return param_value
|
||||
elif param_type.startswith("num") or param_type.startswith(
|
||||
"float"):
|
||||
try:
|
||||
float_param_value = float(param_value)
|
||||
param_value = float_param_value if float_param_value - int(
|
||||
float_param_value) != 0 else int(
|
||||
float_param_value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a float in tool "
|
||||
"'%s', degenerating to string.", param_value,
|
||||
param_name, func_name)
|
||||
return param_value
|
||||
elif param_type in ["boolean", "bool", "binary"]:
|
||||
param_value = param_value.lower()
|
||||
if param_value not in ["true", "false"]:
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a boolean "
|
||||
"(`true` of `false`) in tool '%s', degenerating to false.",
|
||||
param_value, param_name, func_name)
|
||||
return param_value == "true"
|
||||
else:
|
||||
if param_type == "object" or param_type.startswith("dict"):
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
return param_value
|
||||
except (ValueError, TypeError, json.JSONDecodeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a valid JSON "
|
||||
"object in tool '%s', will try other methods to parse it.",
|
||||
param_value, param_name, func_name)
|
||||
try:
|
||||
param_value = ast.literal_eval(param_value)
|
||||
except (ValueError, SyntaxError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' cannot be converted via "
|
||||
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
|
||||
param_value, param_name, func_name)
|
||||
return param_value
|
||||
|
||||
# Extract function name
|
||||
end_index = function_call_str.index(">")
|
||||
function_name = function_call_str[:end_index]
|
||||
param_config = get_arguments_config(function_name)
|
||||
parameters = function_call_str[end_index + 1:]
|
||||
param_dict = {}
|
||||
for match in self.tool_call_parameter_regex.findall(parameters):
|
||||
match_text = match[0] if match[0] else match[1]
|
||||
idx = match_text.index(">")
|
||||
param_name = match_text[:idx]
|
||||
param_value = str(match_text[idx + 1:])
|
||||
# Remove prefix and trailing \n
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = convert_param_value(
|
||||
param_value, param_name, param_config, function_name)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(param_dict,
|
||||
ensure_ascii=False)),
|
||||
)
|
||||
|
||||
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||
# Find all tool calls
|
||||
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||
raw_tool_calls = [
|
||||
match[0] if match[0] else match[1] for match in matched_ranges
|
||||
]
|
||||
|
||||
# Back-off strategy if no tool_call tags found
|
||||
if len(raw_tool_calls) == 0:
|
||||
raw_tool_calls = [model_output]
|
||||
|
||||
raw_function_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
raw_function_calls.extend(
|
||||
self.tool_call_function_regex.findall(tool_call))
|
||||
|
||||
function_calls = [
|
||||
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||
]
|
||||
return function_calls
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# Quick check to avoid unnecessary processing
|
||||
if self.tool_call_prefix not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# Check if both think start and end tokens are present
|
||||
if (self.think_start_token in model_output
|
||||
and self.think_end_token in model_output):
|
||||
# Find the position of think end token
|
||||
think_end_index = model_output.find(self.think_end_token) + len(
|
||||
self.think_end_token)
|
||||
# Extract content after think end token
|
||||
result_content = model_output[think_end_index:]
|
||||
thinking_content = model_output[:think_end_index]
|
||||
else:
|
||||
thinking_content = ""
|
||||
result_content = model_output
|
||||
|
||||
try:
|
||||
function_calls = self._get_function_calls(result_content)
|
||||
if len(function_calls) == 0:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
tool_calls = [
|
||||
self._parse_xml_function_call(function_call_str, request.tools)
|
||||
for function_call_str in function_calls
|
||||
]
|
||||
|
||||
# Populate prev_tool_call_arr for serving layer to set finish_reason
|
||||
self.prev_tool_call_arr.clear() # Clear previous calls
|
||||
for tool_call in tool_calls:
|
||||
if tool_call:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name":
|
||||
tool_call.function.name,
|
||||
"arguments":
|
||||
tool_call.function.arguments,
|
||||
})
|
||||
|
||||
# Extract content before tool calls
|
||||
tool_call_start_index = result_content.find(
|
||||
self.tool_call_start_token)
|
||||
tool_call_start_index = (
|
||||
tool_call_start_index if tool_call_start_index >= 0 else
|
||||
result_content.find(self.tool_call_prefix))
|
||||
content = thinking_content + result_content[:tool_call_start_index]
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=(len(tool_calls) > 0),
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# If no delta text, return None unless
|
||||
# it's an EOS token after tool calls
|
||||
if not delta_text:
|
||||
# Check if this is an EOS token after all tool calls are complete
|
||||
# We check for tool calls in the text even if is_tool_call_started
|
||||
# is False because it might have been reset after processing all tools
|
||||
if (delta_token_ids
|
||||
and self.tool_call_end_token_id not in delta_token_ids):
|
||||
# Count complete tool calls
|
||||
complete_calls = len(
|
||||
self.tool_call_complete_regex.findall(current_text))
|
||||
|
||||
# If we have completed tool calls and populated prev_tool_call_arr
|
||||
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
|
||||
# Check if all tool calls are closed
|
||||
open_calls = current_text.count(
|
||||
self.tool_call_start_token) - current_text.count(
|
||||
self.tool_call_end_token)
|
||||
if open_calls == 0:
|
||||
# Return empty delta message to allow finish_reason processing
|
||||
return DeltaMessage(content="")
|
||||
elif not self.is_tool_call_started and current_text:
|
||||
# This is a regular content response that's now complete
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
# Check if this is the first call (reset state if needed)
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
# Update accumulated text
|
||||
self.accumulated_text = current_text
|
||||
|
||||
# Check if we need to advance to next tool
|
||||
if self.json_closed and not self.in_function:
|
||||
# Check if this tool call has ended
|
||||
tool_ends = current_text.count(self.tool_call_end_token)
|
||||
if tool_ends > self.current_tool_index:
|
||||
# This tool has ended, advance to next
|
||||
self.current_tool_index += 1
|
||||
self.header_sent = False
|
||||
self.param_count = 0
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
|
||||
# Check if there are more tool calls
|
||||
if self.current_tool_index >= current_text.count(
|
||||
self.tool_call_start_token):
|
||||
# No more tool calls
|
||||
self.is_tool_call_started = False
|
||||
# Continue processing next tool
|
||||
return None
|
||||
|
||||
# Check if end thinking
|
||||
if (not self.is_thinking_end
|
||||
and (self.think_end_token_id in delta_token_ids
|
||||
or self.think_end_token in delta_text)):
|
||||
self.is_thinking_end = True
|
||||
|
||||
# If thinking hasn't ended yet, don't process any tool calls
|
||||
if not self.is_thinking_end:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Handle normal content before tool calls
|
||||
if not self.is_tool_call_started:
|
||||
# Check if tool call is starting
|
||||
if (self.tool_call_start_token_id in delta_token_ids
|
||||
or self.tool_call_start_token in delta_text):
|
||||
self.is_tool_call_started = True
|
||||
# Return any content before the tool call
|
||||
if self.tool_call_start_token in delta_text:
|
||||
content_before = delta_text[:delta_text.index(
|
||||
self.tool_call_start_token)]
|
||||
if content_before:
|
||||
return DeltaMessage(content=content_before)
|
||||
return None
|
||||
else:
|
||||
# Check if we're between tool calls - skip whitespace
|
||||
if (current_text.rstrip().endswith(self.tool_call_end_token)
|
||||
and delta_text.strip() == ""):
|
||||
# We just ended a tool call, skip whitespace
|
||||
return None
|
||||
# Normal content, no tool call
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Check if we're between tool calls (waiting for next one)
|
||||
# Count tool calls we've seen vs processed
|
||||
tool_starts_count = current_text.count(self.tool_call_start_token)
|
||||
if self.current_tool_index >= tool_starts_count:
|
||||
# We're past all tool calls, shouldn't be here
|
||||
return None
|
||||
|
||||
# We're in a tool call, find the current tool call portion
|
||||
# Need to find the correct tool call based on current_tool_index
|
||||
# Only process tool calls after think_end_token
|
||||
think_end_index = current_text.find(self.think_end_token) + len(
|
||||
self.think_end_token
|
||||
) if self.think_end_token in current_text else 0
|
||||
tool_starts: list[int] = []
|
||||
idx = think_end_index
|
||||
while True:
|
||||
idx = current_text.find(self.tool_call_start_token, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
tool_starts.append(idx)
|
||||
idx += len(self.tool_call_start_token)
|
||||
|
||||
if self.current_tool_index >= len(tool_starts):
|
||||
# No more tool calls to process yet
|
||||
return None
|
||||
|
||||
tool_start_idx = tool_starts[self.current_tool_index]
|
||||
# Find where this tool call ends (or current position if not ended yet)
|
||||
tool_end_idx = current_text.find(self.tool_call_end_token,
|
||||
tool_start_idx)
|
||||
if tool_end_idx == -1:
|
||||
tool_text = current_text[tool_start_idx:]
|
||||
else:
|
||||
tool_text = current_text[tool_start_idx:tool_end_idx +
|
||||
len(self.tool_call_end_token)]
|
||||
|
||||
# Looking for function header
|
||||
if not self.header_sent:
|
||||
if self.tool_call_prefix in tool_text:
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_end = tool_text.find(">", func_start)
|
||||
|
||||
if func_end != -1:
|
||||
# Found complete function name
|
||||
self.current_function_name = tool_text[func_start:func_end]
|
||||
self.current_tool_id = self._generate_tool_call_id(
|
||||
) # type: ignore
|
||||
self.header_sent = True
|
||||
self.in_function = True
|
||||
|
||||
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
|
||||
# This ensures finish_reason="tool_calls" even if parsing isn't complete
|
||||
already_added = any(
|
||||
tool.get("name") == self.current_function_name
|
||||
for tool in self.prev_tool_call_arr)
|
||||
if not already_added:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name": self.current_function_name,
|
||||
"arguments":
|
||||
"{}", # Placeholder, will be updated later
|
||||
})
|
||||
|
||||
# Send header with function info
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
id=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=self.current_function_name, arguments=""),
|
||||
type="function",
|
||||
)
|
||||
])
|
||||
return None
|
||||
|
||||
# We've sent header, now handle function body
|
||||
if self.in_function:
|
||||
# Send opening brace if not sent yet
|
||||
if (not self.json_started
|
||||
and self.parameter_prefix not in delta_text):
|
||||
self.json_started = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="{"),
|
||||
)
|
||||
])
|
||||
|
||||
# Make sure json_started is set if we're processing parameters
|
||||
if not self.json_started:
|
||||
self.json_started = True
|
||||
|
||||
# Check for function end in accumulated text
|
||||
if not self.json_closed and self.function_end_token in tool_text:
|
||||
# Close JSON
|
||||
self.json_closed = True
|
||||
|
||||
# Extract the complete tool call to update prev_tool_call_arr with final arguments
|
||||
# Find the function content
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_content_end = tool_text.find(self.function_end_token,
|
||||
func_start)
|
||||
if func_content_end != -1:
|
||||
func_content = tool_text[func_start:func_content_end]
|
||||
# Parse to get the complete arguments
|
||||
try:
|
||||
parsed_tool = self._parse_xml_function_call(
|
||||
func_content, request.tools if request else None)
|
||||
if parsed_tool:
|
||||
# Update existing entry in prev_tool_call_arr with complete arguments
|
||||
for i, tool in enumerate(self.prev_tool_call_arr):
|
||||
if tool.get(
|
||||
"name") == parsed_tool.function.name:
|
||||
self.prev_tool_call_arr[i]["arguments"] = (
|
||||
parsed_tool.function.arguments)
|
||||
break
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse tool arguments during streaming.",
|
||||
exc_info=True)
|
||||
|
||||
result = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="}"),
|
||||
)
|
||||
])
|
||||
|
||||
# Reset state for next tool
|
||||
self.in_function = False
|
||||
self.json_closed = True
|
||||
|
||||
return result
|
||||
|
||||
# Look for parameters
|
||||
# Count how many complete parameters we have processed
|
||||
complete_params = tool_text.count(self.parameter_end_token)
|
||||
|
||||
# Check if we should start a new parameter
|
||||
if not self.in_param and self.param_count < complete_params:
|
||||
# Find the unprocessed parameter
|
||||
# Count parameter starts
|
||||
param_starts = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = tool_text.find(self.parameter_prefix, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
param_starts.append(idx)
|
||||
idx += len(self.parameter_prefix)
|
||||
|
||||
if len(param_starts) > self.param_count:
|
||||
# Process the next parameter
|
||||
param_idx = param_starts[self.param_count]
|
||||
param_start = param_idx + len(self.parameter_prefix)
|
||||
remaining = tool_text[param_start:]
|
||||
|
||||
if ">" in remaining:
|
||||
# We have the complete parameter name
|
||||
name_end = remaining.find(">")
|
||||
self.current_param_name = remaining[:name_end]
|
||||
|
||||
# Find the parameter value
|
||||
value_start = param_start + name_end + 1
|
||||
value_text = tool_text[value_start:]
|
||||
if value_text.startswith("\n"):
|
||||
value_text = value_text[1:]
|
||||
|
||||
# Find where this parameter ends
|
||||
param_end_idx = value_text.find(
|
||||
self.parameter_end_token)
|
||||
if param_end_idx != -1:
|
||||
# Complete parameter found
|
||||
param_value = value_text[:param_end_idx]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Build complete JSON fragment for this parameter
|
||||
if self.param_count == 0:
|
||||
json_fragment = (
|
||||
'"' + self.current_param_name + '": "' +
|
||||
json.dumps(param_value)[1:-1] + '"')
|
||||
else:
|
||||
json_fragment = (
|
||||
', "' + self.current_param_name + '": "' +
|
||||
json.dumps(param_value)[1:-1] + '"')
|
||||
|
||||
self.param_count += 1
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=json_fragment),
|
||||
)
|
||||
])
|
||||
|
||||
# Continue parameter value
|
||||
if self.in_param:
|
||||
if self.parameter_end_token in delta_text:
|
||||
# End of parameter
|
||||
end_idx = delta_text.find(self.parameter_end_token)
|
||||
value_chunk = delta_text[:end_idx]
|
||||
|
||||
# Skip past > if at start
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
# Calculate incremental JSON
|
||||
full_value = self.current_param_value + value_chunk
|
||||
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
|
||||
if self.current_param_value else "")
|
||||
full_escaped = json.dumps(full_value)[1:-1]
|
||||
delta_escaped = full_escaped[len(prev_escaped):]
|
||||
|
||||
self.in_param = False
|
||||
self.current_param_value = ""
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_escaped + '"'),
|
||||
)
|
||||
])
|
||||
else:
|
||||
# Continue accumulating value
|
||||
value_chunk = delta_text
|
||||
|
||||
# Handle first chunk after param name
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
if value_chunk:
|
||||
# Stream the escaped delta
|
||||
prev_escaped = (json.dumps(
|
||||
self.current_param_value)[1:-1]
|
||||
if self.current_param_value else "")
|
||||
self.current_param_value += value_chunk
|
||||
full_escaped = json.dumps(
|
||||
self.current_param_value)[1:-1]
|
||||
delta_escaped = full_escaped[len(prev_escaped):]
|
||||
|
||||
if delta_escaped:
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_escaped),
|
||||
)
|
||||
])
|
||||
|
||||
return None
|
||||
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["step3"])
|
||||
class Step3ToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for a model that uses a specific XML-like format for tool calls.
|
||||
This version uses a robust, stateful, cursor-based streaming parser and
|
||||
consolidates tool arguments into a single message.
|
||||
"""
|
||||
|
||||
TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
|
||||
TOOL_CALLS_END = "<|tool_calls_end|>"
|
||||
TOOL_CALL_BEGIN = "<|tool_call_begin|>"
|
||||
TOOL_CALL_END = "<|tool_call_end|>"
|
||||
TOOL_SEP = "<|tool_sep|>"
|
||||
SPECIAL_TOKENS = [
|
||||
TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
|
||||
]
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
# Explicit state flags for robust streaming
|
||||
self.tool_block_started = False
|
||||
self.tool_block_finished = False
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def _parse_steptml_invoke(
|
||||
action_text: str
|
||||
) -> tuple[Optional[str], Optional[dict[str, str]]]:
|
||||
func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
|
||||
action_text)
|
||||
if not func_name_match:
|
||||
return None, None
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
params: dict[str, str] = {}
|
||||
param_matches = re.findall(
|
||||
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
|
||||
action_text)
|
||||
for name, value in param_matches:
|
||||
params[name] = value.strip()
|
||||
return func_name, params
|
||||
|
||||
def _cast_arguments(
|
||||
self,
|
||||
func_name: str,
|
||||
params: dict[str, Any],
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict[str, Any]:
|
||||
for tool in request.tools or []:
|
||||
if tool.function.name == func_name:
|
||||
schema = tool.function.parameters or {}
|
||||
properties = schema.get("properties", {})
|
||||
for key, value in params.items():
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
prop = properties.get(key, {})
|
||||
typ = prop.get("type")
|
||||
if typ == "string":
|
||||
params[key] = value.strip()
|
||||
elif typ == "integer":
|
||||
with contextlib.suppress(ValueError):
|
||||
params[key] = int(value)
|
||||
elif typ == "number":
|
||||
with contextlib.suppress(ValueError):
|
||||
params[key] = float(value)
|
||||
elif typ == "boolean":
|
||||
lower_val = value.lower()
|
||||
params[key] = lower_val == "true" if lower_val in (
|
||||
"true", "false") else value
|
||||
elif typ == "null":
|
||||
params[key] = None if value.lower(
|
||||
) == "null" else value
|
||||
break
|
||||
return params
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# The main loop processes the stream from the last known position.
|
||||
while True:
|
||||
if self.position >= len(current_text):
|
||||
return None # We've processed the entire stream.
|
||||
|
||||
unprocessed_text = current_text[self.position:]
|
||||
|
||||
# STATE: After all tools are done, all subsequent text is content.
|
||||
if self.tool_block_finished:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=unprocessed_text)
|
||||
|
||||
# STATE: Before the tool block has started.
|
||||
if not self.tool_block_started:
|
||||
if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
|
||||
self.position += len(self.TOOL_CALLS_BEGIN)
|
||||
self.tool_block_started = True
|
||||
continue # Token consumed, re-loop.
|
||||
|
||||
start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
|
||||
if start_pos == -1:
|
||||
if self.TOOL_CALLS_BEGIN.startswith(
|
||||
unprocessed_text.strip()) and unprocessed_text:
|
||||
return None # It's a prefix, wait.
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=unprocessed_text)
|
||||
else:
|
||||
content = unprocessed_text[:start_pos]
|
||||
self.position += len(content)
|
||||
return DeltaMessage(content=content)
|
||||
|
||||
# STATE: Inside the main tool block.
|
||||
offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
|
||||
unprocessed_text = unprocessed_text.lstrip()
|
||||
self.position += offset
|
||||
|
||||
if unprocessed_text.startswith(self.TOOL_CALLS_END):
|
||||
self.position += len(self.TOOL_CALLS_END)
|
||||
self.tool_block_finished = True
|
||||
self.current_tool_id = -1
|
||||
continue
|
||||
|
||||
# Check if we are between tool calls.
|
||||
tool_finished = (
|
||||
self.current_tool_id != -1 and
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("finished"))
|
||||
if self.current_tool_id == -1 or tool_finished:
|
||||
if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
|
||||
self.position += len(self.TOOL_CALL_BEGIN)
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
else:
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id]["finished"] = False
|
||||
continue
|
||||
|
||||
if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
|
||||
return None
|
||||
|
||||
# STATE: Parsing an active tool call.
|
||||
if self.current_tool_id != -1 and not self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("finished", False):
|
||||
end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
|
||||
if end_tool_pos == -1:
|
||||
tool_body = unprocessed_text
|
||||
else:
|
||||
tool_body = unprocessed_text[:end_tool_pos]
|
||||
|
||||
if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
|
||||
tool_body):
|
||||
return None
|
||||
|
||||
function_name, arguments = self._parse_steptml_invoke(
|
||||
tool_body)
|
||||
if not function_name:
|
||||
return None
|
||||
|
||||
tool_call_arr = {
|
||||
"name": function_name,
|
||||
"parameters": arguments or {}
|
||||
}
|
||||
|
||||
# Send the function name as soon as it's parsed.
|
||||
if not self.current_tool_name_sent:
|
||||
self.current_tool_name_sent = True
|
||||
self.prev_tool_call_arr[self.current_tool_id].update(
|
||||
tool_call_arr)
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name))
|
||||
])
|
||||
|
||||
# Update our internal state with the latest parsed arguments.
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id].update( # noqa: E501
|
||||
tool_call_arr)
|
||||
|
||||
# Only send arguments when the tool call is complete.
|
||||
if end_tool_pos != -1:
|
||||
self.position += end_tool_pos + len(self.TOOL_CALL_END)
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id]["finished"] = True
|
||||
|
||||
final_args = self._cast_arguments(
|
||||
function_name,
|
||||
tool_call_arr.get("parameters", {}), # type: ignore
|
||||
request)
|
||||
if final_args:
|
||||
final_args_json = json.dumps(final_args,
|
||||
ensure_ascii=False)
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=final_args_json))
|
||||
])
|
||||
|
||||
# If tool is not finished, return None to wait for more tokens.
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if self.TOOL_CALLS_BEGIN not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
|
||||
if self.TOOL_CALLS_END not in rest:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
|
||||
content = (pre_text + post_text).strip()
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
|
||||
|
||||
for part in call_parts:
|
||||
if not part or self.TOOL_CALL_END not in part:
|
||||
continue
|
||||
|
||||
call_content = part.split(self.TOOL_CALL_END, 1)[0]
|
||||
if self.TOOL_SEP not in call_content:
|
||||
continue
|
||||
|
||||
type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
|
||||
if type_part.strip() != "function":
|
||||
continue
|
||||
|
||||
function_name, params_dict = self._parse_steptml_invoke(
|
||||
invoke_part)
|
||||
|
||||
if function_name and params_dict is not None:
|
||||
params_dict = self._cast_arguments(function_name, params_dict,
|
||||
request)
|
||||
params_str = json.dumps(params_dict, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ToolCall(function=FunctionCall(name=function_name,
|
||||
arguments=params_str)))
|
||||
if tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
124
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
124
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the secnod argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, '', 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def find_all_indices(string: str, substring: str) -> list[int]:
|
||||
"""
|
||||
Find all (starting) indices of a substring in a given string. Useful for
|
||||
tool call extraction
|
||||
"""
|
||||
indices = []
|
||||
index = -1
|
||||
while True:
|
||||
index = string.find(substring, index + 1)
|
||||
if index == -1:
|
||||
break
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
||||
524
vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
Normal file
524
vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
Normal file
@@ -0,0 +1,524 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("xlam")
|
||||
class xLAMToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize state for streaming mode
|
||||
self.prev_tool_calls: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args: list[str] = [
|
||||
] # Track arguments sent for each tool
|
||||
|
||||
# For backward compatibility with tests
|
||||
self.current_tools_sent: list[bool] = []
|
||||
|
||||
# For backward compatibility with serving code
|
||||
self.prev_tool_call_arr = []
|
||||
|
||||
# Regex patterns for preprocessing
|
||||
self.json_code_block_patterns = [
|
||||
r"```(?:json)?\s*([\s\S]*?)```",
|
||||
r"\[TOOL_CALLS\]([\s\S]*?)(?=\n|$)",
|
||||
r"<tool_call>([\s\S]*?)</tool_call>",
|
||||
]
|
||||
self.thinking_tag_pattern = r"</think>([\s\S]*)"
|
||||
|
||||
# Define streaming state type to be initialized later
|
||||
self.streaming_state: dict[str, Any] = {
|
||||
"current_tool_index": -1,
|
||||
"tool_ids": [],
|
||||
"sent_tools": [],
|
||||
}
|
||||
|
||||
def preprocess_model_output(
|
||||
self, model_output: str) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Preprocess the model output to extract content and potential tool calls.
|
||||
Returns:
|
||||
Tuple of (content, potential_tool_calls_json)
|
||||
"""
|
||||
# Check for thinking tag
|
||||
thinking_match = re.search(self.thinking_tag_pattern, model_output)
|
||||
if thinking_match:
|
||||
content = model_output[:thinking_match.start() +
|
||||
len("</think>")].strip()
|
||||
thinking_content = thinking_match.group(1).strip()
|
||||
|
||||
# Try to parse the thinking content as JSON
|
||||
try:
|
||||
json.loads(thinking_content)
|
||||
return content, thinking_content
|
||||
except json.JSONDecodeError:
|
||||
# If can't parse as JSON, look for JSON code blocks
|
||||
for json_pattern in self.json_code_block_patterns:
|
||||
json_matches = re.findall(json_pattern, thinking_content)
|
||||
if json_matches:
|
||||
for json_str in json_matches:
|
||||
try:
|
||||
json.loads(json_str)
|
||||
return content, json_str
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Check for JSON code blocks in the entire output
|
||||
for json_pattern in self.json_code_block_patterns:
|
||||
json_matches = re.findall(json_pattern, model_output)
|
||||
if json_matches:
|
||||
for json_str in json_matches:
|
||||
try:
|
||||
json.loads(json_str)
|
||||
# Extract content by removing the JSON code block
|
||||
content = re.sub(json_pattern, "",
|
||||
model_output).strip()
|
||||
return content, json_str
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# If the entire output is a valid JSON array or looks like one, treat it as tool calls
|
||||
if model_output.strip().startswith("["):
|
||||
try:
|
||||
json.loads(model_output)
|
||||
return None, model_output
|
||||
except json.JSONDecodeError:
|
||||
# Even if it's not valid JSON yet, it might be a tool call in progress
|
||||
if ("{" in model_output and "name" in model_output
|
||||
and "arguments" in model_output):
|
||||
return None, model_output
|
||||
|
||||
# If no tool calls found, return the original output as content
|
||||
return model_output, None
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract tool calls from a complete model output.
|
||||
"""
|
||||
try:
|
||||
# Preprocess the model output
|
||||
content, potential_tool_calls = self.preprocess_model_output(
|
||||
model_output)
|
||||
|
||||
if not potential_tool_calls:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=content)
|
||||
|
||||
# Parse the potential tool calls as JSON
|
||||
tool_calls_data = json.loads(potential_tool_calls)
|
||||
|
||||
# Ensure it's an array
|
||||
if not isinstance(tool_calls_data, list):
|
||||
logger.debug("Tool calls data is not an array")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=content or model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
for idx, call in enumerate(tool_calls_data):
|
||||
if (not isinstance(call, dict) or "name" not in call
|
||||
or "arguments" not in call):
|
||||
logger.debug("Invalid tool call format at index %d", idx)
|
||||
continue
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=f"call_{idx}_{random_uuid()}",
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=call["name"],
|
||||
arguments=(json.dumps(call["arguments"]) if isinstance(
|
||||
call["arguments"], dict) else call["arguments"]),
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting tool calls: %s", str(e))
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract tool calls for streaming mode.
|
||||
"""
|
||||
# First, check for a definitive start of a tool call block.
|
||||
# This prevents premature parsing of incomplete output.
|
||||
stripped_text = current_text.strip()
|
||||
preprocessed_content, preprocessed_tool_calls = (
|
||||
self.preprocess_model_output(current_text))
|
||||
|
||||
# For JSON code blocks, we need to detect them earlier, even if incomplete
|
||||
has_potential_json_block = ("```json" in current_text
|
||||
or "```\n[" in current_text
|
||||
or "[TOOL_CALLS]" in current_text
|
||||
or "<tool_call>" in current_text)
|
||||
|
||||
is_tool_call_block = (
|
||||
stripped_text.startswith("[")
|
||||
or stripped_text.startswith("<tool_call>")
|
||||
or stripped_text.startswith("[TOOL_CALLS]") or
|
||||
# Check if we have thinking tags with JSON-like content following
|
||||
("</think>[" in current_text) or
|
||||
# Check if the text contains a JSON array after preprocessing
|
||||
preprocessed_tool_calls is not None or
|
||||
# For JSON code blocks, detect early if we see enough structure
|
||||
(has_potential_json_block and '"name"' in current_text
|
||||
and '"arguments"' in current_text))
|
||||
|
||||
if not is_tool_call_block:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
# Initialize streaming state if not exists
|
||||
if not hasattr(self, "streaming_state"):
|
||||
self.streaming_state = {
|
||||
"current_tool_index": -1,
|
||||
"tool_ids": [],
|
||||
"sent_tools": [], # Track complete state of each tool
|
||||
}
|
||||
|
||||
# Try parsing as JSON to check for complete tool calls
|
||||
try:
|
||||
# Use preprocessed tool calls if available
|
||||
tool_calls_text = (preprocessed_tool_calls if
|
||||
preprocessed_tool_calls else current_text)
|
||||
parsed_tools = json.loads(tool_calls_text)
|
||||
if isinstance(parsed_tools, list):
|
||||
# Update our tool array for next time
|
||||
self.prev_tool_call_arr = parsed_tools
|
||||
except json.JSONDecodeError:
|
||||
# Not complete JSON yet, use regex for partial parsing
|
||||
pass
|
||||
|
||||
# Check for test-specific state setup (current_tools_sent)
|
||||
# This handles the case where tests manually set current_tools_sent
|
||||
if (hasattr(self, "current_tools_sent") # type: ignore
|
||||
and len(self.current_tools_sent) > 0):
|
||||
# If current_tools_sent is set to [False], it means the test wants us to send the name
|
||||
if (len(self.current_tools_sent) == 1
|
||||
and self.current_tools_sent[0] is False):
|
||||
# Extract the function name using regex
|
||||
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
||||
name_match = re.search(name_pattern, current_text)
|
||||
if name_match:
|
||||
function_name = name_match.group(1)
|
||||
|
||||
# The test expects us to send just the name first
|
||||
tool_id = make_tool_call_id()
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
# Update state to reflect that we've sent the name
|
||||
self.current_tools_sent = [True]
|
||||
self.current_tool_id = 0
|
||||
self.streaming_state["current_tool_index"] = 0
|
||||
if len(self.streaming_state["sent_tools"]) == 0:
|
||||
self.streaming_state["sent_tools"].append({
|
||||
"sent_name":
|
||||
True,
|
||||
"sent_arguments_prefix":
|
||||
False,
|
||||
"sent_arguments":
|
||||
"",
|
||||
})
|
||||
else:
|
||||
self.streaming_state["sent_tools"][0][
|
||||
"sent_name"] = True
|
||||
self.current_tool_name_sent = True
|
||||
return delta
|
||||
|
||||
# Use regex to identify tool calls in the output
|
||||
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
|
||||
search_text = (preprocessed_tool_calls
|
||||
if preprocessed_tool_calls else current_text)
|
||||
|
||||
# For JSON code blocks that aren't complete yet, try to extract the JSON content
|
||||
if not preprocessed_tool_calls and has_potential_json_block:
|
||||
# Try to extract the JSON array from within the code block
|
||||
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
|
||||
current_text)
|
||||
if json_match:
|
||||
potential_json = json_match.group(1).strip()
|
||||
# Use this as search text even if it's incomplete
|
||||
if potential_json.startswith("[") and (
|
||||
'"name"' in potential_json
|
||||
and '"arguments"' in potential_json):
|
||||
search_text = potential_json
|
||||
|
||||
# Try to find complete tool names first
|
||||
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
||||
name_matches = list(re.finditer(name_pattern, search_text))
|
||||
tool_count = len(name_matches)
|
||||
|
||||
# If no complete tool names found, check for partial tool names
|
||||
if tool_count == 0:
|
||||
# Check if we're in the middle of parsing a tool name
|
||||
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
|
||||
partial_matches = list(
|
||||
re.finditer(partial_name_pattern, search_text))
|
||||
if partial_matches:
|
||||
# We have a partial tool name - not ready to emit yet
|
||||
return None
|
||||
else:
|
||||
# No tools found at all
|
||||
return None
|
||||
|
||||
# Ensure our state arrays are large enough
|
||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||
self.streaming_state["sent_tools"].append({
|
||||
"sent_name":
|
||||
False,
|
||||
"sent_arguments_prefix":
|
||||
False,
|
||||
"sent_arguments":
|
||||
"",
|
||||
})
|
||||
|
||||
while len(self.streaming_state["tool_ids"]) < tool_count:
|
||||
self.streaming_state["tool_ids"].append(None)
|
||||
|
||||
# Determine if we need to move to a new tool
|
||||
current_idx = self.streaming_state["current_tool_index"]
|
||||
|
||||
# If we haven't processed any tool yet or current tool is complete, move to next
|
||||
if current_idx == -1 or current_idx < tool_count - 1:
|
||||
next_idx = current_idx + 1
|
||||
|
||||
# If tool at next_idx has not been sent yet
|
||||
if (next_idx < tool_count
|
||||
and not self.streaming_state["sent_tools"][next_idx]
|
||||
["sent_name"]):
|
||||
# Update indexes
|
||||
self.streaming_state["current_tool_index"] = next_idx
|
||||
self.current_tool_id = (
|
||||
next_idx # For backward compatibility
|
||||
)
|
||||
current_idx = next_idx
|
||||
|
||||
# Extract the tool name
|
||||
tool_name = name_matches[current_idx].group(1)
|
||||
|
||||
# Generate ID and send tool name
|
||||
tool_id = f"call_{current_idx}_{random_uuid()}"
|
||||
self.streaming_state["tool_ids"][current_idx] = tool_id
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name).model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_name"] = True
|
||||
self.current_tool_name_sent = (
|
||||
True # For backward compatibility
|
||||
)
|
||||
|
||||
# Keep track of streamed args for backward compatibility
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
|
||||
return delta
|
||||
|
||||
# Process arguments for the current tool
|
||||
if current_idx >= 0 and current_idx < tool_count:
|
||||
# Support both regular and empty argument objects
|
||||
# First, check for the empty arguments case: "arguments": {}
|
||||
empty_args_pattern = (
|
||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
||||
empty_args_match = re.search(empty_args_pattern, search_text)
|
||||
|
||||
# Check if this tool has empty arguments
|
||||
if empty_args_match and empty_args_match.start() > 0:
|
||||
# Find which tool this empty arguments belongs to
|
||||
empty_args_tool_idx = 0
|
||||
for i in range(tool_count):
|
||||
if i == current_idx:
|
||||
# If this is our current tool and it has empty arguments
|
||||
if not self.streaming_state["sent_tools"][
|
||||
current_idx]["sent_arguments_prefix"]:
|
||||
# Send empty object
|
||||
self.streaming_state["sent_tools"][
|
||||
current_idx][
|
||||
"sent_arguments_prefix"] = True
|
||||
self.streaming_state["sent_tools"][
|
||||
current_idx]["sent_arguments"] = "{}"
|
||||
|
||||
# Update streamed_args for backward compatibility
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += "{}"
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{}").
|
||||
model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
|
||||
# Move to next tool if available
|
||||
if current_idx < tool_count - 1:
|
||||
self.streaming_state[
|
||||
"current_tool_index"] += 1
|
||||
self.current_tool_id = self.streaming_state[
|
||||
"current_tool_index"]
|
||||
|
||||
return delta
|
||||
|
||||
# Extract arguments for current tool using regex for non-empty arguments
|
||||
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
||||
args_matches = list(re.finditer(args_pattern, search_text))
|
||||
|
||||
if current_idx < len(args_matches):
|
||||
args_text = args_matches[current_idx].group(1)
|
||||
|
||||
# Handle transition between tools
|
||||
is_last_tool = current_idx == tool_count - 1
|
||||
|
||||
# For multiple tools, extract only the arguments for the current tool
|
||||
if tool_count > 1:
|
||||
# Parse the entire JSON structure to properly extract arguments for each tool
|
||||
try:
|
||||
parsed_tools = json.loads(search_text)
|
||||
if isinstance(
|
||||
parsed_tools,
|
||||
list) and current_idx < len(parsed_tools):
|
||||
current_tool = parsed_tools[current_idx]
|
||||
if isinstance(current_tool.get("arguments"),
|
||||
dict):
|
||||
args_text = json.dumps(
|
||||
current_tool["arguments"])
|
||||
else:
|
||||
args_text = str(
|
||||
current_tool.get("arguments", "{}"))
|
||||
except (json.JSONDecodeError, KeyError, IndexError):
|
||||
# Fallback to regex-based extraction
|
||||
pass
|
||||
|
||||
# If arguments haven't been sent yet
|
||||
sent_args = self.streaming_state["sent_tools"][
|
||||
current_idx]["sent_arguments"]
|
||||
|
||||
# If we haven't sent the opening bracket yet
|
||||
if not self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"] and args_text.startswith(
|
||||
"{"):
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments_prefix"] = True
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"] = "{"
|
||||
|
||||
# Update streamed_args for backward compatibility
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += "{"
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{").model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
# If we need to send more arguments
|
||||
if args_text.startswith(sent_args):
|
||||
# Calculate what part of arguments we need to send
|
||||
args_diff = args_text[len(sent_args):]
|
||||
|
||||
if args_diff:
|
||||
# Update our state
|
||||
self.streaming_state["sent_tools"][current_idx][
|
||||
"sent_arguments"] = args_text
|
||||
|
||||
# Update streamed_args for backward compatibility
|
||||
while len(self.streamed_args) <= current_idx:
|
||||
self.streamed_args.append("")
|
||||
self.streamed_args[current_idx] += args_diff
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=args_diff).model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
# If the tool's arguments are complete, check if we need to move to the next tool
|
||||
if args_text.endswith("}") and args_text == sent_args:
|
||||
# This tool is complete, move to the next one in the next iteration
|
||||
if current_idx < tool_count - 1:
|
||||
self.streaming_state["current_tool_index"] += 1
|
||||
self.current_tool_id = self.streaming_state[
|
||||
"current_tool_index"] # For compatibility
|
||||
|
||||
# If we got here, we couldn't determine what to stream next
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in streaming tool calls: {e}")
|
||||
# If we encounter an error, just return the delta text as regular content
|
||||
return DeltaMessage(content=delta_text)
|
||||
Reference in New Issue
Block a user