初始化项目,由ModelHub XC社区提供模型
Model: domyn/Domyn-Small-v1.0 Source: Original Platform
This commit is contained in:
314
tool_parser_plugin.py
Normal file
314
tool_parser_plugin.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Custom vLLM tool parser plugin for models that use <tool_call> XML tags.
|
||||
|
||||
The model outputs tool calls in this format:
|
||||
<tool_call>
|
||||
{"name": "function_name", "arguments": {"arg1": "val1"}}
|
||||
</tool_call>
|
||||
|
||||
Multiple tool calls can appear in a single response (parallel tool calling).
|
||||
|
||||
Usage:
|
||||
vllm serve <model> \
|
||||
--enable-auto-tool-choice \
|
||||
--tool-parser-plugin /absolute/path/to/tool_parser_plugin.py \
|
||||
--tool-call-parser xml_tool_call \
|
||||
--chat-template /absolute/path/to/tool_chat_template.jinja
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import compatibility: vLLM >=0.8 moved tool_parsers to vllm.tool_parsers;
|
||||
# older versions keep them under vllm.entrypoints.openai.tool_parsers.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
# Newer vLLM, roughly 0.15+
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
except ImportError:
|
||||
# Older vLLM
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
try:
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
|
||||
except ImportError:
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _generate_tool_call_id() -> str:
|
||||
"""Generate a unique tool-call ID in the format expected by OpenAI."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Register the parser so it can be referenced via --tool-call-parser
|
||||
# ---------------------------------------------------------------------------
|
||||
@ToolParserManager.register_module(["xml_tool_call"])
|
||||
class XMLToolCallParser(ToolParser):
|
||||
"""
|
||||
Parses tool calls wrapped in <tool_call>...</tool_call> XML tags.
|
||||
|
||||
Handles both single and parallel (multiple) tool calls in one response.
|
||||
Supports streaming and non-streaming extraction.
|
||||
"""
|
||||
|
||||
# Regex to match complete <tool_call>...</tool_call> blocks
|
||||
TOOL_CALL_RE = re.compile(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Regex that also matches an incomplete (still-streaming) block
|
||||
TOOL_CALL_OPEN_RE = re.compile(
|
||||
r"<tool_call>\s*(.*?)(?:</tool_call>|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
TOOL_CALL_START = "<tool_call>"
|
||||
TOOL_CALL_END = "</tool_call>"
|
||||
|
||||
def __init__(self, tokenizer, tools=None):
|
||||
# vLLM newer versions: ToolParser.__init__(tokenizer, tools)
|
||||
# vLLM older versions: ToolParser.__init__(tokenizer)
|
||||
try:
|
||||
super().__init__(tokenizer, tools)
|
||||
except TypeError:
|
||||
super().__init__(tokenizer)
|
||||
self.tools = tools or []
|
||||
|
||||
# ---- streaming state ----
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optional: adjust the request before inference
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _parse_tool_json(raw: str) -> dict | None:
|
||||
"""Parse a tool call JSON block, handling Python-style single quotes."""
|
||||
# Try standard JSON first
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
# Fall back to ast.literal_eval for Python-style dicts with single quotes
|
||||
try:
|
||||
result = ast.literal_eval(raw)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
return None
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionRequest:
|
||||
return request
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# NON-STREAMING extraction
|
||||
# ------------------------------------------------------------------
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Parse all <tool_call>...</tool_call> blocks from the full model
|
||||
output and convert them to OpenAI ToolCall objects.
|
||||
"""
|
||||
|
||||
# Find all complete tool-call blocks
|
||||
raw_matches = self.TOOL_CALL_RE.findall(model_output)
|
||||
|
||||
if not raw_matches:
|
||||
# No tool calls found — return the text as-is
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
for raw_json in raw_matches:
|
||||
parsed = self._parse_tool_json(raw_json)
|
||||
if parsed is None:
|
||||
logger.warning(
|
||||
"Failed to parse tool call JSON: %s", raw_json
|
||||
)
|
||||
continue
|
||||
|
||||
fn_name = parsed.get("name", "")
|
||||
fn_args = parsed.get("arguments", {})
|
||||
|
||||
# Ensure arguments is a JSON string (OpenAI format)
|
||||
if isinstance(fn_args, dict):
|
||||
fn_args_str = json.dumps(fn_args)
|
||||
elif isinstance(fn_args, str):
|
||||
# Model may emit arguments as a JSON string — validate and pass through
|
||||
try:
|
||||
json.loads(fn_args)
|
||||
fn_args_str = fn_args
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Try ast.literal_eval for Python-style dicts (e.g. single quotes,
|
||||
# unquoted keys). If that also fails, emit an empty dict so
|
||||
# downstream json.loads never sees an invalid string.
|
||||
try:
|
||||
recovered = ast.literal_eval(fn_args)
|
||||
fn_args_str = json.dumps(recovered) if isinstance(recovered, dict) else json.dumps({})
|
||||
except (ValueError, SyntaxError):
|
||||
fn_args_str = "{}"
|
||||
else:
|
||||
fn_args_str = str(fn_args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=_generate_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=fn_name,
|
||||
arguments=fn_args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Strip tool-call blocks from content to get any surrounding text
|
||||
remaining_content = self.TOOL_CALL_RE.sub("", model_output).strip()
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=remaining_content if remaining_content else None,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# STREAMING extraction
|
||||
# ------------------------------------------------------------------
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Incrementally parse tool calls from the streaming token output.
|
||||
|
||||
Strategy:
|
||||
- Before seeing <tool_call>, stream tokens as regular content.
|
||||
- Once <tool_call> is detected, buffer until </tool_call>.
|
||||
- On </tool_call>, emit the complete tool call delta.
|
||||
- Support multiple sequential tool calls.
|
||||
"""
|
||||
|
||||
# If we haven't seen a tool_call opening tag yet, pass through as
|
||||
# regular content (unless the start tag is partially forming).
|
||||
if self.TOOL_CALL_START not in current_text:
|
||||
# Check if the current text ends with a partial match of the
|
||||
# start tag — if so, hold back to avoid emitting partial tags.
|
||||
for i in range(1, len(self.TOOL_CALL_START)):
|
||||
if current_text.endswith(self.TOOL_CALL_START[:i]):
|
||||
# Possibly forming the start tag — hold delta
|
||||
return None
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# ---- We are inside or past a <tool_call> block ----
|
||||
|
||||
# Find all *complete* tool call blocks so far
|
||||
complete_matches = self.TOOL_CALL_RE.findall(current_text)
|
||||
num_complete = len(complete_matches)
|
||||
|
||||
# Determine how many we've already streamed
|
||||
num_already_sent = len(self.prev_tool_call_arr)
|
||||
|
||||
if num_complete > num_already_sent:
|
||||
# A new tool call just completed — emit it
|
||||
new_raw = complete_matches[num_already_sent]
|
||||
parsed = self._parse_tool_json(new_raw)
|
||||
if parsed is None:
|
||||
logger.warning(
|
||||
"Streaming: failed to parse tool call JSON: %s",
|
||||
new_raw,
|
||||
)
|
||||
return None
|
||||
|
||||
fn_name = parsed.get("name", "")
|
||||
fn_args = parsed.get("arguments", {})
|
||||
if isinstance(fn_args, dict):
|
||||
fn_args_str = json.dumps(fn_args)
|
||||
elif isinstance(fn_args, str):
|
||||
try:
|
||||
json.loads(fn_args)
|
||||
fn_args_str = fn_args
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
try:
|
||||
recovered = ast.literal_eval(fn_args)
|
||||
fn_args_str = json.dumps(recovered) if isinstance(recovered, dict) else json.dumps({})
|
||||
except (ValueError, SyntaxError):
|
||||
fn_args_str = "{}"
|
||||
else:
|
||||
fn_args_str = str(fn_args)
|
||||
|
||||
self.current_tool_id += 1
|
||||
self.prev_tool_call_arr.append(parsed)
|
||||
self.streamed_args_for_tool.append(fn_args_str)
|
||||
self.current_tool_name_sent = True
|
||||
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
id=_generate_tool_call_id(),
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=fn_name,
|
||||
arguments=fn_args_str,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# If we're currently inside an incomplete tool call block,
|
||||
# don't emit anything — wait for it to complete.
|
||||
# Check if there's an open <tool_call> without a matching close
|
||||
open_count = current_text.count(self.TOOL_CALL_START)
|
||||
close_count = current_text.count(self.TOOL_CALL_END)
|
||||
if open_count > close_count:
|
||||
# Still buffering inside a tool call
|
||||
return None
|
||||
|
||||
# If we're past all tool call blocks, stream remaining content
|
||||
# (unlikely for most models but handles edge cases)
|
||||
return None
|
||||
Reference in New Issue
Block a user