Fix: Runtime error for function calling (#3300)

This commit is contained in:
Shi Shuai
2025-02-07 04:52:01 +00:00
committed by GitHub
parent 40022d075a
commit 591e751e07
2 changed files with 97 additions and 70 deletions

View File

@@ -20,7 +20,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker). - [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker).
## Optimisations ## Optimizations
### Multi-head Latent Attention (MLA) Throughput Optimizations ### Multi-head Latent Attention (MLA) Throughput Optimizations

View File

@@ -1,4 +1,5 @@
import json import json
import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
TOOLS_TAG_LIST = [ TOOLS_TAG_LIST = [
"<|plugin|>", "<|plugin|>",
"<function=", "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
self.bot_token = "" self.bot_token = ""
self.eot_token = "" self.eot_token = ""
def parse_base_json(self, action: Dict, tools: List[Function]): def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
name, parameters = action["name"], json.dumps( tool_indices = {
action.get("parameters", action.get("arguments", {})), tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
}
if not isinstance(action, list):
name = action.get("name")
if not name or name not in tool_indices:
logger.warning(f"Model attempted to call undefined function: {name}")
return []
return [
ToolCallItem(
tool_index=tool_indices[name],
name=name,
parameters=json.dumps(
action.get("parameters") or action.get("arguments", {}),
ensure_ascii=False, ensure_ascii=False,
),
) )
tool_index = [tool.function.name for tool in tools].index(name) ]
tool_call_item = ToolCallItem(
tool_index=tool_index, name=name, parameters=parameters results = []
for act in action:
name = act.get("name")
if name and name in tool_indices:
results.append(
ToolCallItem(
tool_index=tool_indices[name],
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
) )
calls = [tool_call_item] )
return calls
return results
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
""" """
@@ -112,9 +141,7 @@ class BaseFormatDetector:
self, new_text: str, tools: List[Function] self, new_text: str, tools: List[Function]
) -> StreamingParseResult: ) -> StreamingParseResult:
""" """
Streaming incremental parsing, referencing the logic of Llama32Detector. Streaming incremental parsing with tool validation.
We partially parse JSON within <tool_call>...</tool_call>, and handle
incremental argument output.
""" """
# Append new text to buffer # Append new text to buffer
self._buffer += new_text self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
new_text = new_text.replace(self.eot_token, "") new_text = new_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=new_text) return StreamingParseResult(normal_text=new_text)
# bit mask flags for partial JSON parsing. If the name hasn't been # Build tool indices if not already built
# sent yet, don't allow sending if not hasattr(self, "_tool_indices"):
# an incomplete string since OpenAI only ever (as far as I have self._tool_indices = {
# seen) allows sending the entire tool/ function name at once. tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try: try:
tool_call_arr = [] tool_call_arr = []
is_complete = [] is_complete = []
try: try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = ( start_idx = (
len(self.bot_token) len(self.bot_token)
if current_text.startswith(self.bot_token) if current_text.startswith(self.bot_token)
@@ -149,8 +178,18 @@ class BaseFormatDetector:
_is_complete_json(current_text[start_idx : start_idx + end_idx]) _is_complete_json(current_text[start_idx : start_idx + end_idx])
) )
start_idx += end_idx + len("; ") start_idx += end_idx + len("; ")
# depending on the prompt Llama can use
# either arguments or parameters # Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()
# Handle parameters/arguments consistency
if "parameters" in obj: if "parameters" in obj:
assert ( assert (
"arguments" not in obj "arguments" not in obj
@@ -159,29 +198,17 @@ class BaseFormatDetector:
tool_call_arr.append(obj) tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON: except partial_json_parser.core.exceptions.MalformedJSON:
# not enough tokens to parse into JSON yet
return StreamingParseResult() return StreamingParseResult()
# select as the current tool call the one we're on the state at if len(tool_call_arr) == 0:
return StreamingParseResult()
current_tool_call: Dict = ( current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
) )
# case -- if no tokens have been streamed for the tool, e.g. # Handle new tool in array
# only the array brackets, stream nothing if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
if len(tool_call_arr) == 0:
return StreamingParseResult()
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0: if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
if cur_arguments: if cur_arguments:
@@ -190,7 +217,6 @@ class BaseFormatDetector:
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
res = StreamingParseResult( res = StreamingParseResult(
normal_text=None,
calls=[ calls=[
ToolCallItem( ToolCallItem(
tool_index=self.current_tool_id, tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@ class BaseFormatDetector:
res = StreamingParseResult() res = StreamingParseResult()
else: else:
res = StreamingParseResult() res = StreamingParseResult()
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False self.current_tool_name_sent = False
self.streamed_args_for_tool.append("") self.streamed_args_for_tool.append("")
print("starting on new tool %d", self.current_tool_id)
return res return res
# if the current tool name hasn't been sent, send if available # Handle tool name
# - otherwise send nothing
elif not self.current_tool_name_sent: elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name") function_name = current_tool_call.get("name")
if function_name: if function_name and function_name in self._tool_indices:
res = StreamingParseResult( res = StreamingParseResult(
normal_text=None,
calls=[ calls=[
ToolCallItem( ToolCallItem(
tool_index=self.current_tool_id, tool_index=self._tool_indices[function_name],
name=function_name, name=function_name,
parameters="", parameters="",
) )
@@ -232,8 +255,7 @@ class BaseFormatDetector:
else: else:
res = StreamingParseResult() res = StreamingParseResult()
# now we know we're on the same tool call and we're streaming # Handle streaming arguments
# arguments
else: else:
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult() res = StreamingParseResult()
@@ -250,13 +272,12 @@ class BaseFormatDetector:
argument_diff = cur_args_json[sent:] argument_diff = cur_args_json[sent:]
self._buffer = "" self._buffer = ""
self.prev_tool_call_arr[self.current_tool_id].clear() self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent: bool = False self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = "" self.streamed_args_for_tool[self.current_tool_id] = ""
elif prev_arguments: elif prev_arguments:
prev_args_json = json.dumps(prev_arguments) prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json: if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json) prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:] argument_diff = prefix[sent:]
@@ -279,8 +300,7 @@ class BaseFormatDetector:
return res return res
except Exception as e: except Exception as e:
print(e) logger.error(f"Error in parse_streaming_increment: {e}")
# Skipping chunk as a result of tool streaming extraction error
return StreamingParseResult() return StreamingParseResult()
@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
Detector for Llama 3.2 models. Detector for Llama 3.2 models.
Assumes function call format: Assumes function call format:
<|python_tag|>{"name":"xxx", "arguments":{...}} <|python_tag|>{"name":"xxx", "arguments":{...}}
Does not require a closing tag "</python_tag|>",
relies on json.loads(...) success to determine if JSON is complete.
""" """
def __init__(self): def __init__(self):
"""
Initializes the detector with necessary state variables.
"""
super().__init__() super().__init__()
self.bot_token = "<|python_tag|>" self.bot_token = "<|python_tag|>"
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
""" """Parse function calls from text, handling multiple JSON objects."""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
if "<|python_tag|>" not in text: if "<|python_tag|>" not in text:
return [] return []
_, action = text.split("<|python_tag|>")
action = json.loads(action) _, action_text = text.split("<|python_tag|>")
return self.parse_base_json(action, tools)
# Split by semicolon and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
all_actions = []
for part in json_parts:
try:
# Parse each individual JSON object
action = json.loads(part)
all_actions.append(action)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON part: {part}")
logger.warning(f"JSON parse error: {str(e)}")
continue
# Only process if we found valid JSON objects
if all_actions:
return self.parse_base_json(all_actions, tools)
return []
class MultiFormatParser: class MultiFormatParser: