314 lines
11 KiB
Python
314 lines
11 KiB
Python
"""
|
|
Custom vLLM tool parser plugin for models that use <tool_call> XML tags.
|
|
|
|
The model outputs tool calls in this format:
|
|
<tool_call>
|
|
{"name": "function_name", "arguments": {"arg1": "val1"}}
|
|
</tool_call>
|
|
|
|
Multiple tool calls can appear in a single response (parallel tool calling).
|
|
|
|
Usage:
|
|
vllm serve <model> \
|
|
--enable-auto-tool-choice \
|
|
--tool-parser-plugin /absolute/path/to/tool_parser_plugin.py \
|
|
--tool-call-parser xml_tool_call \
|
|
--chat-template /absolute/path/to/tool_chat_template.jinja
|
|
"""
|
|
|
|
import ast
|
|
import json
|
|
import re
|
|
import uuid
|
|
from typing import Sequence, Union
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Import compatibility: vLLM >=0.8 moved tool_parsers to vllm.tool_parsers;
|
|
# older versions keep them under vllm.entrypoints.openai.tool_parsers.
|
|
# ---------------------------------------------------------------------------
|
|
try:
|
|
# Newer vLLM, roughly 0.15+
|
|
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
|
from vllm.entrypoints.openai.engine.protocol import (
|
|
DeltaFunctionCall,
|
|
DeltaMessage,
|
|
DeltaToolCall,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
except ImportError:
|
|
# Older vLLM
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ChatCompletionRequest,
|
|
DeltaFunctionCall,
|
|
DeltaMessage,
|
|
DeltaToolCall,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
|
|
try:
|
|
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
|
|
except ImportError:
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
|
ToolParser,
|
|
ToolParserManager,
|
|
)
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def _generate_tool_call_id() -> str:
|
|
"""Generate a unique tool-call ID in the format expected by OpenAI."""
|
|
return f"call_{uuid.uuid4().hex[:24]}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Register the parser so it can be referenced via --tool-call-parser
|
|
# ---------------------------------------------------------------------------
|
|
@ToolParserManager.register_module(["xml_tool_call"])
class XMLToolCallParser(ToolParser):
    """
    Parses tool calls wrapped in <tool_call>...</tool_call> XML tags.

    Each block is expected to contain a JSON object (or a Python-style dict
    literal) with a ``name`` and an optional ``arguments`` entry.

    Handles both single and parallel (multiple) tool calls in one response.
    Supports streaming and non-streaming extraction.
    """

    # Regex to match complete <tool_call>...</tool_call> blocks
    TOOL_CALL_RE = re.compile(
        r"<tool_call>\s*(.*?)\s*</tool_call>",
        re.DOTALL,
    )

    # Regex that also matches an incomplete (still-streaming) block.
    # Not used by this class itself; kept as a public attribute.
    TOOL_CALL_OPEN_RE = re.compile(
        r"<tool_call>\s*(.*?)(?:</tool_call>|$)",
        re.DOTALL,
    )

    TOOL_CALL_START = "<tool_call>"
    TOOL_CALL_END = "</tool_call>"

    # Number of completed <tool_call> blocks already examined while
    # streaming. Counts malformed blocks too, so one bad block cannot stall
    # the blocks that follow it. Re-initialised per instance in __init__.
    _num_blocks_consumed = 0

    def __init__(self, tokenizer, tools=None):
        # vLLM newer versions: ToolParser.__init__(tokenizer, tools)
        # vLLM older versions: ToolParser.__init__(tokenizer)
        try:
            super().__init__(tokenizer, tools)
        except TypeError:
            super().__init__(tokenizer)
        self.tools = tools or []

        # ---- streaming state ----
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        self._num_blocks_consumed = 0

    # ------------------------------------------------------------------
    # Parsing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _literal_eval_dict(raw: str) -> Union[dict, None]:
        """Evaluate *raw* as a Python literal; return it only if it is a dict.

        ``ast.literal_eval`` is documented to raise several exception types
        on malformed input, so all of them are treated as "not parseable".
        """
        try:
            value = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        return value if isinstance(value, dict) else None

    @staticmethod
    def _parse_tool_json(raw: str) -> Union[dict, None]:
        """Parse a tool call JSON block, handling Python-style single quotes.

        Returns the decoded dict, or ``None`` when *raw* is neither valid
        JSON nor a Python dict literal, or when it decodes to something
        other than a dict (callers rely on ``.get``).
        """
        # Try standard JSON first
        try:
            decoded = json.loads(raw)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError subclass.
            pass
        else:
            return decoded if isinstance(decoded, dict) else None
        # Fall back to ast.literal_eval for Python-style dicts with single quotes
        return XMLToolCallParser._literal_eval_dict(raw)

    @classmethod
    def _arguments_to_json(cls, fn_args) -> str:
        """Render the ``arguments`` value of a tool call as a JSON string.

        - ``None`` becomes ``"{}"``.
        - A string that is already valid JSON is passed through unchanged.
        - A string that is not valid JSON is retried as a Python dict
          literal; if that fails too, ``"{}"`` is returned so consumers
          never receive an undecodable string.
        - Anything else (dict, list, number, ...) is serialised with
          ``json.dumps``; values that cannot be serialised yield ``"{}"``.
        """
        if fn_args is None:
            return "{}"

        if isinstance(fn_args, str):
            try:
                json.loads(fn_args)
            except (ValueError, RecursionError):
                recovered = cls._literal_eval_dict(fn_args)
                fn_args = recovered if recovered is not None else {}
            else:
                return fn_args

        try:
            return json.dumps(fn_args)
        except (TypeError, ValueError):
            # e.g. a set or bytes value produced by the literal_eval fallback
            logger.warning(
                "Tool call arguments are not JSON serializable: %r", fn_args
            )
            return "{}"

    @classmethod
    def _decode_block(cls, raw: str) -> Union[tuple, None]:
        """Decode the payload of one <tool_call> block.

        Returns ``(parsed_dict, function_name, arguments_json)``, or
        ``None`` when the payload is not a dict or its ``name`` is not a
        string. A missing ``name`` is kept as the empty string.
        """
        parsed = cls._parse_tool_json(raw)
        if parsed is None:
            return None
        fn_name = parsed.get("name", "")
        if not isinstance(fn_name, str):
            return None
        return parsed, fn_name, cls._arguments_to_json(parsed.get("arguments", {}))

    @classmethod
    def _content_boundary(cls, text: str) -> int:
        """Return how many leading characters of *text* are plain content.

        If *text* contains the start tag, the boundary is the index of its
        first occurrence. Otherwise it is the length of *text* minus the
        longest suffix that is a proper prefix of the start tag, because
        such a suffix may still turn into a tag and must be held back.
        """
        tag = cls.TOOL_CALL_START
        tag_at = text.find(tag)
        if tag_at != -1:
            return tag_at
        for size in range(min(len(tag) - 1, len(text)), 0, -1):
            if text.endswith(tag[:size]):
                return len(text) - size
        return len(text)

    # ------------------------------------------------------------------
    # Optional: adjust the request before inference
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        """Return *request* unchanged; this parser needs no adjustments."""
        return request

    # ------------------------------------------------------------------
    # NON-STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse all <tool_call>...</tool_call> blocks from the full model
        output and convert them to OpenAI ToolCall objects.

        Blocks that cannot be decoded are logged and skipped. If no block
        yields a usable tool call, the output is returned untouched as
        plain content with ``tools_called=False``. Otherwise every
        <tool_call> block is removed from the text and whatever remains
        (stripped) is returned as ``content``, or ``None`` if nothing
        remains. ``request`` is not used.
        """
        tool_calls: list[ToolCall] = []
        for raw_json in self.TOOL_CALL_RE.findall(model_output):
            decoded = self._decode_block(raw_json)
            if decoded is None:
                logger.warning(
                    "Failed to parse tool call JSON: %s", raw_json
                )
                continue

            _, fn_name, fn_args_str = decoded
            tool_calls.append(
                ToolCall(
                    id=_generate_tool_call_id(),
                    type="function",
                    function=FunctionCall(
                        name=fn_name,
                        arguments=fn_args_str,
                    ),
                )
            )

        if not tool_calls:
            # No usable tool calls — return the text as-is rather than
            # claiming tools were called with an empty list.
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        # Strip tool-call blocks from content to get any surrounding text
        remaining_content = self.TOOL_CALL_RE.sub("", model_output).strip()

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=remaining_content if remaining_content else None,
        )

    # ------------------------------------------------------------------
    # STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally parse tool calls from the streaming token output.

        Strategy:
        - Text before the first <tool_call> is streamed as regular content.
          A trailing fragment that could be the beginning of the start tag
          is held back and emitted later if it turns out not to be a tag.
        - Once <tool_call> is detected, buffer until </tool_call>.
        - On </tool_call>, emit the complete tool call delta. Every block
          completed since the previous call is emitted, so several blocks
          closing within one delta are not lost.
        - Malformed blocks are logged and skipped without blocking later
          ones.

        Assumes ``current_text == previous_text + delta_text``.

        Limitations: text that appears between or after tool-call blocks is
        not emitted as content, and a held-back fragment at the very end of
        the stream is not flushed by this method. The token-id arguments
        and ``request`` are not used.

        Returns a DeltaMessage carrying content and/or tool calls, or None
        when there is nothing to emit for this delta.
        """
        message_fields = {}

        # ---- Plain content preceding the first tool call ----
        safe_upto = self._content_boundary(current_text)
        # Clamp defensively so an inconsistent previous_text cannot produce
        # a reversed slice.
        sent_upto = min(self._content_boundary(previous_text), safe_upto)
        new_content = current_text[sent_upto:safe_upto]
        if new_content:
            message_fields["content"] = new_content

        # ---- Tool-call blocks completed since the last invocation ----
        if self.TOOL_CALL_START in current_text:
            completed_blocks = self.TOOL_CALL_RE.findall(current_text)
            tool_deltas = []
            for new_raw in completed_blocks[self._num_blocks_consumed:]:
                # Count the block as handled even when it is malformed.
                self._num_blocks_consumed += 1
                decoded = self._decode_block(new_raw)
                if decoded is None:
                    logger.warning(
                        "Streaming: failed to parse tool call JSON: %s",
                        new_raw,
                    )
                    continue

                parsed, fn_name, fn_args_str = decoded
                self.current_tool_id += 1
                self.prev_tool_call_arr.append(parsed)
                self.streamed_args_for_tool.append(fn_args_str)
                self.current_tool_name_sent = True

                tool_deltas.append(
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=_generate_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=fn_name,
                            arguments=fn_args_str,
                        ),
                    )
                )
            if tool_deltas:
                message_fields["tool_calls"] = tool_deltas

        # Nothing to emit: still buffering inside a block, holding back a
        # possible partial tag, or past the blocks (see Limitations).
        if not message_fields:
            return None
        return DeltaMessage(**message_fields)