Constraint Decoding: Tool call with text (#4067)
This commit is contained in:
@@ -128,13 +128,15 @@ class BaseFormatDetector:
|
||||
|
||||
return results
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
"""
|
||||
action = json.loads(text)
|
||||
return self.parse_base_json(action, tools)
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
@@ -322,7 +324,9 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
@@ -330,15 +334,17 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
if "<tool_call>" not in text:
|
||||
return []
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
calls = []
|
||||
for match_result in match_result_list:
|
||||
match_result = json.loads(match_result)
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
@@ -374,7 +380,9 @@ class MistralDetector(BaseFormatDetector):
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
@@ -382,6 +390,8 @@ class MistralDetector(BaseFormatDetector):
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
@@ -391,7 +401,7 @@ class MistralDetector(BaseFormatDetector):
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
@@ -414,7 +424,7 @@ class Llama32Detector(BaseFormatDetector):
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""Parse function calls from text, handling multiple JSON objects."""
|
||||
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||
return []
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
if "<|python_tag|>" in text:
|
||||
_, action_text = text.split("<|python_tag|>")
|
||||
@@ -423,7 +433,6 @@ class Llama32Detector(BaseFormatDetector):
|
||||
|
||||
# Split by semicolon and process each part
|
||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||
|
||||
all_actions = []
|
||||
for part in json_parts:
|
||||
try:
|
||||
@@ -434,12 +443,11 @@ class Llama32Detector(BaseFormatDetector):
|
||||
logger.warning(f"Failed to parse JSON part: {part}")
|
||||
logger.warning(f"JSON parse error: {str(e)}")
|
||||
continue
|
||||
|
||||
calls = []
|
||||
# Only process if we found valid JSON objects
|
||||
if all_actions:
|
||||
return self.parse_base_json(all_actions, tools)
|
||||
|
||||
return []
|
||||
calls = self.parse_base_json(all_actions, tools)
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
|
||||
class MultiFormatParser:
|
||||
@@ -449,7 +457,9 @@ class MultiFormatParser:
|
||||
"""
|
||||
self.detectors = detectors
|
||||
|
||||
def parse_once(self, text: str, tools: List[Function]):
|
||||
def parse_once(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||
Return: (final_text, all_calls)
|
||||
@@ -459,15 +469,19 @@ class MultiFormatParser:
|
||||
final_calls = []
|
||||
final_normal_text = text
|
||||
for detector in self.detectors:
|
||||
tool_call_list = detector.detect_and_parse(text, tools)
|
||||
parsed_result = detector.detect_and_parse(text, tools)
|
||||
tool_call_list = parsed_result.calls
|
||||
if len(tool_call_list) > 0: # parsed successfully
|
||||
final_calls = tool_call_list
|
||||
final_normal_text = parsed_result.normal_text
|
||||
break
|
||||
|
||||
# leftover_text is the normal text not consumed by any Detector
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def parse_streaming_increment(self, new_text: str, tools: List[Function]):
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||
and merge their produced normal_text/calls to return.
|
||||
@@ -532,7 +546,7 @@ class FunctionCallParser:
|
||||
return True
|
||||
return False
|
||||
|
||||
def parse_non_stream(self, full_text: str):
|
||||
def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
Non-streaming call: one-time parsing
|
||||
"""
|
||||
@@ -541,7 +555,7 @@ class FunctionCallParser:
|
||||
)
|
||||
return full_normal_text, calls
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str):
|
||||
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
Streaming call: incremental parsing
|
||||
"""
|
||||
|
||||
@@ -1130,7 +1130,7 @@ def v1_chat_generate_response(
|
||||
finish_reason["type"] = "tool_calls"
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info.tool_index),
|
||||
@@ -1153,9 +1153,9 @@ def v1_chat_generate_response(
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text if tool_calls is None else None,
|
||||
"content": text if text else None,
|
||||
"tool_calls": tool_calls,
|
||||
"reasoning_content": reasoning_text,
|
||||
"reasoning_content": reasoning_text if reasoning_text else None,
|
||||
},
|
||||
"logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
|
||||
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
||||
@@ -1170,9 +1170,9 @@ def v1_chat_generate_response(
|
||||
index=idx,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=text if tool_calls is None else None,
|
||||
content=text if text else None,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_text,
|
||||
reasoning_content=reasoning_text if reasoning_text else None,
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
@@ -1317,9 +1317,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
):
|
||||
delta = DeltaMessage(role="assistant", reasoning_content="")
|
||||
delta = DeltaMessage(
|
||||
role="assistant", reasoning_content=None
|
||||
)
|
||||
else:
|
||||
delta = DeltaMessage(role="assistant", content="")
|
||||
delta = DeltaMessage(role="assistant", content=None)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=delta,
|
||||
@@ -1362,7 +1364,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
if reasoning_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
||||
delta=DeltaMessage(
|
||||
reasoning_content=(
|
||||
reasoning_text if reasoning_text else None
|
||||
)
|
||||
),
|
||||
finish_reason=(
|
||||
None
|
||||
if finish_reason_type
|
||||
@@ -1396,7 +1402,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
if normal_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=normal_text),
|
||||
delta=DeltaMessage(
|
||||
content=normal_text if normal_text else None
|
||||
),
|
||||
finish_reason=(
|
||||
None
|
||||
if finish_reason_type
|
||||
@@ -1468,7 +1476,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
# No tool calls => just treat this as normal text
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
delta=DeltaMessage(content=delta if delta else None),
|
||||
finish_reason=(
|
||||
None
|
||||
if finish_reason_type and len(finish_reason_type) == 0
|
||||
|
||||
Reference in New Issue
Block a user