Add support for the DeepSeek model by enabling streaming response parsing (#5592)
This commit is contained in:
@@ -491,6 +491,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
self.eot_token = "<|tool▁calls▁end|>"
|
||||
self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
|
||||
self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
|
||||
self._last_arguments = ""
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a deepseek format tool call."""
|
||||
@@ -528,13 +529,84 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
    """Return a factory mapping a function *name* to its structural markers.

    For DeepSeekV3, a tool call for function ``name`` is wrapped as::

        >{name}
        ```json
        {arguments}
        ```<

    Returns:
        A callable that, given a function name, yields a ``StructureInfo``
        with ``begin``/``end`` delimiters and the ``trigger`` string used
        to detect the start of a call for that specific function.
    """
    # NOTE: the corrupted diff merge had duplicated begin/end/trigger
    # keyword arguments here (old and new versions interleaved), which is
    # a SyntaxError; this is the post-change (streaming-aware) version.
    return lambda name: StructureInfo(
        begin=">" + name + "\n```json\n",
        end="\n```<",
        trigger=">" + name + "\n```json\n",
    )
|
||||
|
||||
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing tool calls for DeepSeekV3 format.

    Accumulates chunks in ``self._buffer``; emits the tool name once
    (with empty parameters), then streams argument text as incremental
    diffs, and resets all per-call state once the JSON argument block
    is complete.
    """
    # Accumulate this chunk; all matching below runs on the full buffer.
    self._buffer += new_text
    current_text = self._buffer

    # No tool-call start marker anywhere in the buffer: treat the chunk
    # as plain text. Flush the buffer and strip stray end-of-call tokens
    # from the chunk before returning it as normal text.
    if self.bot_token not in current_text:
        self._buffer = ""
        for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
            if e_token in new_text:
                new_text = new_text.replace(e_token, "")
        return StreamingParseResult(normal_text=new_text)

    # Build name -> index map lazily, once per detector instance.
    # NOTE(review): hasattr caching means a later call with a different
    # `tools` list reuses the stale map — confirm detectors are
    # per-request before relying on this.
    if not hasattr(self, "_tool_indices"):
        self._tool_indices = {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }

    calls: list[ToolCallItem] = []
    try:
        # Match a possibly-incomplete call: name, separator, then
        # whatever JSON argument text has arrived so far (group 3).
        # NOTE(review): `|` is regex alternation and is not escaped
        # here, so the pattern does not literally match the
        # "<|tool▁call▁begin|>" token — it matches looser alternatives.
        # It appears to work in practice; confirm before tightening.
        partial_match = re.search(
            pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
            string=current_text,
            flags=re.DOTALL,
        )
        if partial_match:
            func_name = partial_match.group(2).strip()
            func_args_raw = partial_match.group(3).strip()

            # First emission for this call carries only the tool name.
            if not self.current_tool_name_sent:
                calls.append(
                    ToolCallItem(
                        tool_index=self._tool_indices.get(func_name, 0),
                        name=func_name,
                        parameters="",
                    )
                )
                self.current_tool_name_sent = True
            else:
                # Subsequent emissions stream only the newly-arrived
                # suffix of the argument text; if the raw text no longer
                # extends what we sent, resend it wholesale.
                argument_diff = (
                    func_args_raw[len(self._last_arguments) :]
                    if func_args_raw.startswith(self._last_arguments)
                    else func_args_raw
                )

                if argument_diff:
                    calls.append(
                        ToolCallItem(
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=None,  # name already sent above
                            parameters=argument_diff,
                        )
                    )
                    self._last_arguments += argument_diff

            # Arguments form complete JSON: the call is done. Reset all
            # per-call state so the next call starts fresh.
            if _is_complete_json(func_args_raw):
                result = StreamingParseResult(normal_text="", calls=calls)
                self._buffer = ""
                self._last_arguments = ""
                self.current_tool_name_sent = False
                return result

        # Call still in progress (or marker seen but no match yet):
        # emit whatever diffs were produced, keep buffering.
        return StreamingParseResult(normal_text="", calls=calls)

    except Exception as e:
        # Best-effort fallback: on any parse failure, surface the raw
        # buffered text as normal output rather than dropping it.
        logger.error(f"Error in parse_streaming_increment: {e}")
        return StreamingParseResult(normal_text=current_text)
|
||||
|
||||
|
||||
class MultiFormatParser:
|
||||
def __init__(self, detectors: List[BaseFormatDetector]):
|
||||
|
||||
@@ -966,8 +966,6 @@ def v1_chat_generate_request(
|
||||
),
|
||||
}
|
||||
)
|
||||
# TODO fix the compatible issues with xgrammar
|
||||
strict_tag = None
|
||||
|
||||
for message in request.messages:
|
||||
if isinstance(message.content, str):
|
||||
|
||||
Reference in New Issue
Block a user