diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 04e2384d9..dd29ad1cc 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -20,7 +20,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker). -## Optimisations +## Optimizations ### Multi-head Latent Attention (MLA) Throughput Optimizations diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py index 3def4e1eb..62ae53e2d 100644 --- a/python/sglang/srt/function_call_parser.py +++ b/python/sglang/srt/function_call_parser.py @@ -1,4 +1,5 @@ import json +import logging import re from abc import ABC, abstractmethod from json import JSONDecodeError, JSONDecoder @@ -8,6 +9,8 @@ import partial_json_parser from partial_json_parser.core.options import Allow from pydantic import BaseModel, Field +logger = logging.getLogger(__name__) + TOOLS_TAG_LIST = [ "<|plugin|>", " List[ToolCallItem]: + tool_indices = { + tool.function.name: i for i, tool in enumerate(tools) if tool.function.name + } + if not isinstance(action, list): + name = action.get("name") + if not name or name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {name}") + return [] + + return [ + ToolCallItem( + tool_index=tool_indices[name], + name=name, + parameters=json.dumps( + action.get("parameters") or action.get("arguments", {}), + ensure_ascii=False, + ), + ) + ] + + results = [] + for act in action: + name = act.get("name") + if name and name in tool_indices: + results.append( + ToolCallItem( + tool_index=tool_indices[name], + name=name, + parameters=json.dumps( + act.get("parameters") or act.get("arguments", {}), + ensure_ascii=False, + ), + ) + ) + + return results def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: """ @@ -112,9 +141,7 @@ class BaseFormatDetector: self, new_text: str, tools: List[Function] ) -> StreamingParseResult: """ - Streaming incremental parsing, referencing the logic of Llama32Detector. - We partially parse JSON within ..., and handle - incremental argument output. + Streaming incremental parsing with tool validation. """ # Append new text to buffer self._buffer += new_text @@ -125,17 +152,19 @@ class BaseFormatDetector: new_text = new_text.replace(self.eot_token, "") return StreamingParseResult(normal_text=new_text) - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. + # Build tool indices if not already built + if not hasattr(self, "_tool_indices"): + self._tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function and tool.function.name + } + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = [] is_complete = [] try: - # depending on the prompt format the Llama model may or may not - # prefix the output with the <|python_tag|> token start_idx = ( len(self.bot_token) if current_text.startswith(self.bot_token) @@ -149,8 +178,18 @@ class BaseFormatDetector: _is_complete_json(current_text[start_idx : start_idx + end_idx]) ) start_idx += end_idx + len("; ") - # depending on the prompt Llama can use - # either arguments or parameters + + # Validate tool name if present + if "name" in obj and obj["name"] not in self._tool_indices: + # Invalid tool name - reset state + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + if self.streamed_args_for_tool: + self.streamed_args_for_tool.pop() + return StreamingParseResult() + + # Handle parameters/arguments consistency if "parameters" in obj: assert ( "arguments" not in obj @@ -159,29 +198,17 @@ class BaseFormatDetector: tool_call_arr.append(obj) except partial_json_parser.core.exceptions.MalformedJSON: - # not enough tokens to parse into JSON yet return StreamingParseResult() - # select as the current tool call the one we're on the state at + if len(tool_call_arr) == 0: + return StreamingParseResult() + current_tool_call: Dict = ( tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} ) - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return StreamingParseResult() - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif ( - len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 - ): - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. + # Handle new tool in array + if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: @@ -190,7 +217,6 @@ class BaseFormatDetector: argument_diff = cur_args_json[sent:] res = StreamingParseResult( - normal_text=None, calls=[ ToolCallItem( tool_index=self.current_tool_id, @@ -206,23 +232,20 @@ class BaseFormatDetector: res = StreamingParseResult() else: res = StreamingParseResult() - # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False self.streamed_args_for_tool.append("") - print("starting on new tool %d", self.current_tool_id) return res - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing + # Handle tool name elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") - if function_name: + if function_name and function_name in self._tool_indices: res = StreamingParseResult( - normal_text=None, calls=[ ToolCallItem( - tool_index=self.current_tool_id, + tool_index=self._tool_indices[function_name], name=function_name, parameters="", ) @@ -232,8 +255,7 @@ class BaseFormatDetector: else: res = StreamingParseResult() - # now we know we're on the same tool call and we're streaming - # arguments + # Handle streaming arguments else: cur_arguments = current_tool_call.get("arguments") res = StreamingParseResult() @@ -250,13 +272,12 @@ class BaseFormatDetector: argument_diff = cur_args_json[sent:] self._buffer = "" self.prev_tool_call_arr[self.current_tool_id].clear() - self.current_tool_name_sent: bool = False + self.current_tool_name_sent = False self.streamed_args_for_tool[self.current_tool_id] = "" elif prev_arguments: prev_args_json = json.dumps(prev_arguments) if cur_args_json != prev_args_json: - prefix = _find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] @@ -279,8 +300,7 @@ class BaseFormatDetector: return res except Exception as e: - print(e) - # Skipping chunk as a result of tool streaming extraction error + logger.error(f"Error in parse_streaming_increment: {e}") return StreamingParseResult() @@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector): Detector for Llama 3.2 models. Assumes function call format: <|python_tag|>{"name":"xxx", "arguments":{...}} - Does not require a closing tag "", - relies on json.loads(...) success to determine if JSON is complete. """ def __init__(self): - """ - Initializes the detector with necessary state variables. - """ super().__init__() self.bot_token = "<|python_tag|>" def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: - """ - One-time parsing: Detects and parses tool calls in the provided text. - - :param text: The complete text to parse. - :param tools: List of available tools. - :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. - """ - + """Parse function calls from text, handling multiple JSON objects.""" if "<|python_tag|>" not in text: return [] - _, action = text.split("<|python_tag|>") - action = json.loads(action) - return self.parse_base_json(action, tools) + + _, action_text = text.split("<|python_tag|>") + + # Split by semicolon and process each part + json_parts = [part.strip() for part in action_text.split(";") if part.strip()] + + all_actions = [] + for part in json_parts: + try: + # Parse each individual JSON object + action = json.loads(part) + all_actions.append(action) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON part: {part}") + logger.warning(f"JSON parse error: {str(e)}") + continue + + # Only process if we found valid JSON objects + if all_actions: + return self.parse_base_json(all_actions, tools) + + return [] class MultiFormatParser: