model: support Step3V (#8583)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: nnnobody-code <nnnobody@foxmail.com> Co-authored-by: ispobock <ispobaoke@gmail.com> Co-authored-by: Qiaolin-Yu <qy254@cornell.edu> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com> Co-authored-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +40,7 @@ class FunctionCallParser:
|
||||
"kimi_k2": KimiK2Detector,
|
||||
"qwen3_coder": Qwen3CoderDetector,
|
||||
"glm45": Glm4MoeDetector,
|
||||
"step3": Step3Detector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
|
||||
436
python/sglang/srt/function_call/step3_detector.py
Normal file
436
python/sglang/srt/function_call/step3_detector.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> str:
|
||||
"""Get the expected type for a function argument from tool schema."""
|
||||
name2tool = {tool.function.name: tool for tool in defined_tools}
|
||||
if func_name not in name2tool:
|
||||
return None
|
||||
tool = name2tool[func_name]
|
||||
parameters = tool.function.parameters or {}
|
||||
properties = parameters.get("properties", {})
|
||||
if arg_key not in properties:
|
||||
return None
|
||||
return properties[arg_key].get("type", None)
|
||||
|
||||
|
||||
def parse_arguments(value: str) -> tuple[Any, bool]:
|
||||
"""Parse a string value to appropriate type. Returns (parsed_value, success)."""
|
||||
try:
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
except:
|
||||
parsed_value = ast.literal_eval(value)
|
||||
return parsed_value, True
|
||||
except:
|
||||
return value, False
|
||||
|
||||
|
||||
class Step3Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Step3 model function call format.
|
||||
|
||||
The Step3 format uses special Unicode tokens to delimit function calls
|
||||
with steptml XML format for invocations.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
|
||||
<steptml:parameter name="param1">value1</steptml:parameter>
|
||||
<steptml:parameter name="param2">value2</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = "<|tool_calls_begin|>"
|
||||
self.eot_token = "<|tool_calls_end|>"
|
||||
self.tool_call_begin = "<|tool_call_begin|>"
|
||||
self.tool_call_end = "<|tool_call_end|>"
|
||||
self.tool_sep = "<|tool_sep|>"
|
||||
|
||||
# Regex for parsing steptml invocations
|
||||
self.invoke_regex = re.compile(
|
||||
r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL
|
||||
)
|
||||
self.param_regex = re.compile(
|
||||
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL
|
||||
)
|
||||
|
||||
# Streaming state variables
|
||||
self._in_tool_block: bool = False
|
||||
self._tool_block_finished: bool = False
|
||||
self._current_function_name: str = ""
|
||||
self._current_parameters: Dict[str, Any] = {}
|
||||
self._in_tool_call: bool = False
|
||||
self._function_name_sent: bool = False
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Step3 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def _parse_steptml_invoke(
|
||||
self, text: str, tools: List[Tool] = None
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse steptml invoke format to extract function name and parameters."""
|
||||
invoke_match = self.invoke_regex.search(text)
|
||||
if not invoke_match:
|
||||
return None, {}
|
||||
|
||||
func_name = invoke_match.group(1)
|
||||
params_text = invoke_match.group(2)
|
||||
|
||||
params = {}
|
||||
for param_match in self.param_regex.finditer(params_text):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
|
||||
# If tools provided, use schema-aware parsing
|
||||
if tools:
|
||||
arg_type = get_argument_type(func_name, param_name, tools)
|
||||
if arg_type and arg_type != "string":
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
params[param_name] = parsed_value
|
||||
else:
|
||||
params[param_name] = param_value
|
||||
else:
|
||||
# Fallback to generic parsing if no tools provided
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
params[param_name] = parsed_value
|
||||
|
||||
return func_name, params
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
"""
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
try:
|
||||
pre_text, rest = text.split(self.bot_token, 1)
|
||||
|
||||
# If no end token, return everything as normal text
|
||||
if self.eot_token not in rest:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
tool_section, post_text = rest.split(self.eot_token, 1)
|
||||
|
||||
# Find all individual tool calls using regex
|
||||
calls = []
|
||||
tool_call_pattern = (
|
||||
f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
|
||||
)
|
||||
|
||||
for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
|
||||
call_content = match.group(1)
|
||||
|
||||
# Check if it's a function call
|
||||
if self.tool_sep not in call_content:
|
||||
continue
|
||||
|
||||
type_part, invoke_part = call_content.split(self.tool_sep, 1)
|
||||
if type_part.strip() != "function":
|
||||
continue
|
||||
|
||||
func_name, params = self._parse_steptml_invoke(invoke_part, tools)
|
||||
if func_name:
|
||||
# Use parse_base_json to create the ToolCallItem
|
||||
action = {"name": func_name, "arguments": params}
|
||||
calls.extend(self.parse_base_json(action, tools))
|
||||
|
||||
# Combine pre and post text
|
||||
normal_text = pre_text + post_text
|
||||
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in detect_and_parse: {e}")
|
||||
# Return the original text if parsing fails
|
||||
return StreamingParseResult(normal_text=text)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing for Step3 format.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
|
||||
# Build tool indices for validation
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = self._get_tool_indices(tools)
|
||||
|
||||
# If we've finished the tool block, everything is normal text
|
||||
if self._tool_block_finished:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# Check if tool block hasn't started yet
|
||||
if not self._in_tool_block:
|
||||
if self.bot_token in self._buffer:
|
||||
idx = self._buffer.find(self.bot_token)
|
||||
normal_text = self._buffer[:idx]
|
||||
self._buffer = self._buffer[idx + len(self.bot_token) :]
|
||||
self._in_tool_block = True
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
else:
|
||||
# Check if we might have a partial bot_token
|
||||
partial_len = self._ends_with_partial_token(
|
||||
self._buffer, self.bot_token
|
||||
)
|
||||
if partial_len:
|
||||
return StreamingParseResult() # Wait for more text
|
||||
else:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# We're inside the tool block
|
||||
calls: List[ToolCallItem] = []
|
||||
|
||||
# Check if tool block is ending
|
||||
if self.eot_token in self._buffer:
|
||||
idx = self._buffer.find(self.eot_token)
|
||||
|
||||
# If we're in the middle of a tool call, we need to handle it
|
||||
if self._in_tool_call:
|
||||
# The buffer before eot_token might contain the end of the current tool call
|
||||
before_eot = self._buffer[:idx]
|
||||
if self.tool_call_end in before_eot:
|
||||
# Parse this final tool call
|
||||
result = self._parse_partial_tool_call(tools)
|
||||
calls.extend(result.calls)
|
||||
else:
|
||||
# Incomplete tool call - log warning
|
||||
logger.warning("Tool block ended with incomplete tool call")
|
||||
|
||||
remaining = self._buffer[idx + len(self.eot_token) :]
|
||||
self._buffer = ""
|
||||
self._tool_block_finished = True
|
||||
|
||||
# Reset any partial tool call state
|
||||
self._reset_streaming_state()
|
||||
|
||||
return StreamingParseResult(normal_text=remaining, calls=calls)
|
||||
|
||||
# Check if we're in a tool call or need to start one
|
||||
if not self._in_tool_call:
|
||||
if self.tool_call_begin in self._buffer:
|
||||
idx = self._buffer.find(self.tool_call_begin)
|
||||
# Remove any content before tool call begin (shouldn't happen but be safe)
|
||||
self._buffer = self._buffer[idx + len(self.tool_call_begin) :]
|
||||
self._in_tool_call = True
|
||||
self._function_name_sent = False
|
||||
self._current_function_name = ""
|
||||
self._current_parameters = {}
|
||||
# Fall through to parse the partial tool call
|
||||
else:
|
||||
# Wait for tool call to begin
|
||||
return StreamingParseResult()
|
||||
|
||||
# Parse partial tool call
|
||||
if self._in_tool_call:
|
||||
return self._parse_partial_tool_call(tools)
|
||||
|
||||
return StreamingParseResult()
|
||||
|
||||
def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""Parse partial tool call for streaming scenarios."""
|
||||
calls = []
|
||||
|
||||
# Check if we have tool_sep (means we're past the type declaration)
|
||||
if self.tool_sep not in self._buffer:
|
||||
return StreamingParseResult(calls=calls) # Wait for more text
|
||||
|
||||
type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
|
||||
if type_part.strip() != "function":
|
||||
# Invalid tool type, skip this tool call
|
||||
self._reset_streaming_state()
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
# Try to extract function name if not sent yet
|
||||
if not self._function_name_sent:
|
||||
name_match = re.search(r'<steptml:invoke name="([^"]+)">', invoke_part)
|
||||
if name_match:
|
||||
func_name = name_match.group(1)
|
||||
|
||||
# Validate function name
|
||||
if func_name in self._tool_indices:
|
||||
self._current_function_name = func_name
|
||||
self._function_name_sent = True
|
||||
|
||||
# Initialize tool tracking
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
|
||||
# Ensure tracking arrays are large enough
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
# Store tool call info
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
}
|
||||
|
||||
# Send tool name with empty parameters
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name=func_name,
|
||||
parameters="",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Invalid function name
|
||||
logger.warning(f"Invalid function name: {func_name}")
|
||||
self._reset_streaming_state()
|
||||
return StreamingParseResult(calls=calls)
|
||||
else:
|
||||
# Function name not complete yet
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
# Parse parameters incrementally
|
||||
if self._function_name_sent:
|
||||
# Extract all complete parameters
|
||||
new_params = {}
|
||||
for param_match in self.param_regex.finditer(invoke_part):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
|
||||
# Use schema-aware parsing
|
||||
arg_type = get_argument_type(
|
||||
self._current_function_name, param_name, tools
|
||||
)
|
||||
if arg_type and arg_type != "string":
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
new_params[param_name] = parsed_value
|
||||
else:
|
||||
new_params[param_name] = param_value
|
||||
|
||||
# Check if we have new parameters to stream
|
||||
if new_params != self._current_parameters:
|
||||
# Build the JSON content without the closing brace for streaming
|
||||
if not self._current_parameters:
|
||||
# First parameters - send opening brace and content
|
||||
params_content = json.dumps(new_params, ensure_ascii=False)
|
||||
if len(params_content) > 2: # More than just "{}"
|
||||
# Send everything except the closing brace
|
||||
diff = params_content[:-1]
|
||||
else:
|
||||
diff = "{"
|
||||
else:
|
||||
# Subsequent parameters - calculate the incremental diff
|
||||
old_json = json.dumps(self._current_parameters, ensure_ascii=False)
|
||||
new_json = json.dumps(new_params, ensure_ascii=False)
|
||||
|
||||
# Remove closing braces for comparison
|
||||
old_without_brace = old_json[:-1]
|
||||
new_without_brace = new_json[:-1]
|
||||
|
||||
# The new content should extend the old content
|
||||
if new_without_brace.startswith(old_without_brace):
|
||||
diff = new_without_brace[len(old_without_brace) :]
|
||||
else:
|
||||
# Parameters changed in unexpected way - shouldn't happen in normal streaming
|
||||
diff = ""
|
||||
|
||||
if diff:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
parameters=diff,
|
||||
)
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
|
||||
# Update current state
|
||||
self._current_parameters = new_params
|
||||
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
|
||||
|
||||
# Check if tool call is complete
|
||||
if self.tool_call_end in self._buffer:
|
||||
# Send closing brace if we've sent any parameters
|
||||
if self.streamed_args_for_tool[self.current_tool_id]:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
parameters="}",
|
||||
)
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += "}"
|
||||
|
||||
# Find the end position
|
||||
end_idx = self._buffer.find(self.tool_call_end)
|
||||
# Remove the processed tool call from buffer
|
||||
self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
|
||||
|
||||
# Reset state for next tool call
|
||||
self._reset_streaming_state()
|
||||
self.current_tool_id += 1
|
||||
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset streaming state for the next tool call"""
|
||||
self._in_tool_call = False
|
||||
self._function_name_sent = False
|
||||
self._current_function_name = ""
|
||||
self._current_parameters = {}
|
||||
|
||||
def supports_structural_tag(self) -> bool:
|
||||
"""Return True if this detector supports structural tag format."""
|
||||
return False
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
raise NotImplementedError()
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build EBNF grammar for Step3 tool call format.
|
||||
"""
|
||||
# Custom call rule for steptml format
|
||||
call_rule_fmt = (
|
||||
'"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
|
||||
'{arguments_rule} "</steptml:invoke>"'
|
||||
)
|
||||
|
||||
# Custom key-value rule for steptml parameters
|
||||
key_value_rule_fmt = (
|
||||
'"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
|
||||
)
|
||||
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
sequence_start_token=self.bot_token,
|
||||
sequence_end_token=self.eot_token,
|
||||
individual_call_start_token=self.tool_call_begin,
|
||||
individual_call_end_token=self.tool_call_end,
|
||||
tool_call_separator="",
|
||||
function_format="xml",
|
||||
call_rule_fmt=call_rule_fmt,
|
||||
key_value_rule_fmt=key_value_rule_fmt,
|
||||
key_value_separator="",
|
||||
)
|
||||
Reference in New Issue
Block a user