diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 98cb5937e..99d11f50a 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -41,7 +41,7 @@ "\n", "\n", "server_process, port = launch_server_cmd(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\" # llama3\n", + " \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n", ")\n", "wait_for_server(f\"http://localhost:{port}\")" ] @@ -55,7 +55,7 @@ "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\nMistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", - "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)." + "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (e.g. Qwen/QwQ-32B). Note that for QwQ, the reasoning parser can be enabled together with the tool call parser; see [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html) for details." ] }, { @@ -121,7 +121,7 @@ " return [\n", " {\n", " \"role\": \"user\",\n", - " \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n", + " \"content\": \"What's the weather like in Boston today? Output your reasoning before acting, then use the tools to help you.\",\n", " }\n", " ]\n", "\n", @@ -164,63 +164,28 @@ "response_non_stream = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", - " temperature=0.8,\n", - " top_p=0.8,\n", + " temperature=0.1,\n", + " top_p=0.95,\n", + " max_tokens=1024,\n", " stream=False, # Non-streaming\n", " tools=tools,\n", ")\n", "print_highlight(\"Non-stream response:\")\n", - "print(response_non_stream)" + "print(response_non_stream)\n", + "print_highlight(\"==== content ====\")\n", + "print(response_non_stream.choices[0].message.content)\n", + "print_highlight(\"==== tool_calls ====\")\n", + "print(response_non_stream.choices[0].message.tool_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Streaming Request" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Streaming mode test\n", - "print_highlight(\"Streaming response:\")\n", - "response_stream = client.chat.completions.create(\n", - " model=model_name,\n", - " messages=messages,\n", - " temperature=0.8,\n", - " top_p=0.8,\n", - " stream=True, # Enable streaming\n", - " tools=tools,\n", - ")\n", - "\n", - "chunks = []\n", - "for chunk in response_stream:\n", - " chunks.append(chunk)\n", - " if chunk.choices[0].delta.tool_calls:\n", - " print(chunk.choices[0].delta.tool_calls[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "### Handle Tool Calls\n", - "\n", + "#### Handle Tools\n", "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Non-Streaming Request**" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], @@ -240,7 +205,50 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Streaming Request**" + "### Streaming Request" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming mode test\n", + "print_highlight(\"Streaming response:\")\n", + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.1,\n", + " top_p=0.95,\n", + " max_tokens=1024,\n", + " stream=True, # Enable streaming\n", + " tools=tools,\n", + ")\n", + "\n", + "texts = \"\"\n", + "tool_calls = []\n", + "for chunk in response_stream:\n", + " if chunk.choices[0].delta.content:\n", + " texts += chunk.choices[0].delta.content\n", + " if chunk.choices[0].delta.tool_calls:\n", + " tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n", + "print_highlight(\"==== Text ====\")\n", + "print(texts)\n", + "\n", + "print_highlight(\"==== Tool Call ====\")\n", + "for tool_call in tool_calls:\n", + " print(tool_call)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Handle Tools\n", + "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." ] }, { @@ -251,21 +259,16 @@ "source": [ "# Parse and combine function call arguments\n", "arguments = []\n", - "for chunk in chunks:\n", - " choice = chunk.choices[0]\n", - " delta = choice.delta\n", - " if delta.tool_calls:\n", - " tool_call = delta.tool_calls[0]\n", - " if tool_call.function.name:\n", - " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", + "for tool_call in tool_calls:\n", + " if tool_call.function.name:\n", + " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", "\n", - " if tool_call.function.arguments:\n", - " arguments.append(tool_call.function.arguments)\n", - " print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n", + " if tool_call.function.arguments:\n", + " arguments.append(tool_call.function.arguments)\n", "\n", "# Combine all fragments into a single JSON string\n", "full_arguments = \"\".join(arguments)\n", - "print_highlight(f\"Final streamed function call arguments: {full_arguments}\")" + "print_highlight(f\"Streamed function call arguments: {full_arguments}\")" ] }, { @@ -342,13 +345,16 @@ "final_response = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", - " temperature=0.8,\n", - " top_p=0.8,\n", + " temperature=0.1,\n", + " top_p=0.95,\n", " stream=False,\n", " tools=tools,\n", ")\n", "print_highlight(\"Non-stream response:\")\n", - "print(final_response)" + "print(final_response)\n", + "\n", + "print_highlight(\"==== Text ====\")\n", + "print(final_response.choices[0].message.content)" ] }, { @@ -368,7 +374,7 @@ "import requests\n", "\n", "# generate an answer\n", - "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n", "\n", "messages = get_messages()\n", "\n", @@ -380,8 +386,17 @@ ")\n", "\n", "gen_url = f\"http://localhost:{port}/generate\"\n", - "gen_data = {\"text\": input, \"sampling_params\": 
{\"skip_special_tokens\": False}}\n", + "gen_data = {\n", + " \"text\": input,\n", + " \"sampling_params\": {\n", + " \"skip_special_tokens\": False,\n", + " \"max_new_tokens\": 1024,\n", + " \"temperature\": 0.1,\n", + " \"top_p\": 0.95,\n", + " },\n", + "}\n", "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", + "print_highlight(\"==== Response ====\")\n", "print(gen_response)\n", "\n", "# parse the response\n", @@ -389,12 +404,16 @@ "\n", "function_call_input = {\n", " \"text\": gen_response,\n", - " \"tool_call_parser\": \"llama3\",\n", + " \"tool_call_parser\": \"qwen25\",\n", " \"tools\": tools,\n", "}\n", "\n", "function_call_response = requests.post(parse_url, json=function_call_input)\n", "function_call_response_json = function_call_response.json()\n", + "\n", + "print_highlight(\"==== Text ====\")\n", + "print(function_call_response_json[\"normal_text\"])\n", + "print_highlight(\"==== Calls ====\")\n", "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n", "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])" ] @@ -425,15 +444,15 @@ "from sglang.srt.function_call_parser import FunctionCallParser\n", "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", - "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n", "tokenizer = llm.tokenizer_manager.tokenizer\n", "input_ids = tokenizer.apply_chat_template(\n", " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", ")\n", "\n", "sampling_params = {\n", - " \"max_new_tokens\": 128,\n", - " \"temperature\": 0.3,\n", + " \"max_new_tokens\": 1024,\n", + " \"temperature\": 0.1,\n", " \"top_p\": 0.95,\n", " \"skip_special_tokens\": False,\n", "}\n", @@ -461,10 +480,10 @@ "\n", "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n", "\n", - "parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n", + "parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n", "normal_text, calls = parser.parse_non_stream(generated_text)\n", "\n", - "print(\"\\n=== Parsing Result ===\")\n", + "print(\"=== Parsing Result ===\")\n", "print(\"Normal text portion:\", normal_text)\n", "print(\"Function call portion:\")\n", "for call in calls:\n", @@ -521,5 +540,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py index 7754ba58c..4ae8d0a0d 100644 --- a/python/sglang/srt/function_call_parser.py +++ b/python/sglang/srt/function_call_parser.py @@ -128,13 +128,15 @@ class BaseFormatDetector: return results - def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + def detect_and_parse( + self, text: str, tools: List[Function] + ) -> StreamingParseResult: """ Parses the text in one go. Returns success=True if the format matches, otherwise False. Note that leftover_text here represents "content that this parser will not consume further".
""" action = json.loads(text) - return self.parse_base_json(action, tools) + return StreamingParseResult(calls=self.parse_base_json(action, tools)) def parse_streaming_increment( self, new_text: str, tools: List[Function] @@ -322,7 +324,9 @@ class Qwen25Detector(BaseFormatDetector): """Check if the text contains a Qwen 2.5 format tool call.""" return self.bot_token in text - def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + def detect_and_parse( + self, text: str, tools: List[Function] + ) -> StreamingParseResult: """ One-time parsing: Detects and parses tool calls in the provided text. @@ -330,15 +334,17 @@ class Qwen25Detector(BaseFormatDetector): :param tools: List of available tools. :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. """ - if "" not in text: - return [] - pattern = r"(.*?)" + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + pattern = rf"{self.bot_token}(.*?){self.eot_token}" match_result_list = re.findall(pattern, text, re.DOTALL) calls = [] for match_result in match_result_list: match_result = json.loads(match_result) calls.extend(self.parse_base_json(match_result, tools)) - return calls + return StreamingParseResult(normal_text=normal_text, calls=calls) class MistralDetector(BaseFormatDetector): @@ -374,7 +380,9 @@ class MistralDetector(BaseFormatDetector): else: return "" - def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + def detect_and_parse( + self, text: str, tools: List[Function] + ) -> StreamingParseResult: """ One-time parsing: Detects and parses tool calls in the provided text. @@ -382,6 +390,8 @@ class MistralDetector(BaseFormatDetector): :param tools: List of available tools. :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
""" + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text text = self._clean_text(text) tool_content = text.replace("[TOOL_CALLS]", "").strip() raw_tool_calls = self.tool_call_regex.findall(tool_content) @@ -391,7 +401,7 @@ class MistralDetector(BaseFormatDetector): function_call_arr = json.loads(raw_tool_call) for match_result in function_call_arr: calls.extend(self.parse_base_json(match_result, tools)) - return calls + return StreamingParseResult(normal_text=normal_text, calls=calls) class Llama32Detector(BaseFormatDetector): @@ -414,7 +424,7 @@ class Llama32Detector(BaseFormatDetector): def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: """Parse function calls from text, handling multiple JSON objects.""" if "<|python_tag|>" not in text and not text.startswith("{"): - return [] + return StreamingParseResult(normal_text=text, calls=[]) if "<|python_tag|>" in text: _, action_text = text.split("<|python_tag|>") @@ -423,7 +433,6 @@ class Llama32Detector(BaseFormatDetector): # Split by semicolon and process each part json_parts = [part.strip() for part in action_text.split(";") if part.strip()] - all_actions = [] for part in json_parts: try: @@ -434,12 +443,11 @@ class Llama32Detector(BaseFormatDetector): logger.warning(f"Failed to parse JSON part: {part}") logger.warning(f"JSON parse error: {str(e)}") continue - + calls = [] # Only process if we found valid JSON objects if all_actions: - return self.parse_base_json(all_actions, tools) - - return [] + calls = self.parse_base_json(all_actions, tools) + return StreamingParseResult(normal_text=normal_text, calls=calls) class MultiFormatParser: @@ -449,7 +457,9 @@ class MultiFormatParser: """ self.detectors = detectors - def parse_once(self, text: str, tools: List[Function]): + def parse_once( + self, text: str, tools: List[Function] + ) -> Tuple[str, list[ToolCallItem]]: """ One-time parsing: Loop through detectors until there are no new matches or text is exhausted Return: (final_text, all_calls) @@ -459,15 +469,19 @@ class MultiFormatParser: final_calls = [] final_normal_text = text for detector in self.detectors: - tool_call_list = detector.detect_and_parse(text, tools) + parsed_result = detector.detect_and_parse(text, tools) + tool_call_list = parsed_result.calls if len(tool_call_list) > 0: # parsed successfully final_calls = tool_call_list + final_normal_text = parsed_result.normal_text break # leftover_text is the normal text not consumed by any Detector return final_normal_text, final_calls - def parse_streaming_increment(self, new_text: str, tools: List[Function]): + def parse_streaming_increment( + self, new_text: str, tools: List[Function] + ) -> Tuple[str, list[ToolCallItem]]: """ Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment and merge their produced normal_text/calls to return. 
@@ -532,7 +546,7 @@ class FunctionCallParser: return True return False - def parse_non_stream(self, full_text: str): + def parse_non_stream(self, full_text: str) -> Tuple[str, List[ToolCallItem]]: """ Non-streaming call: one-time parsing """ @@ -541,7 +555,7 @@ ) return full_normal_text, calls - def parse_stream_chunk(self, chunk_text: str): + def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, List[ToolCallItem]]: """ Streaming call: incremental parsing """ diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index ba4ca7b4f..a9f1124ac 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -1130,7 +1130,7 @@ finish_reason["type"] = "tool_calls" finish_reason["matched"] = None try: - full_normal_text, call_info_list = parser.parse_non_stream(text) + text, call_info_list = parser.parse_non_stream(text) tool_calls = [ ToolCall( id=str(call_info.tool_index), @@ -1153,9 +1153,9 @@ "index": 0, "message": { "role": "assistant", - "content": text if tool_calls is None else None, + "content": text if text else None, "tool_calls": tool_calls, - "reasoning_content": reasoning_text, + "reasoning_content": reasoning_text if reasoning_text else None, }, "logprobs": choice_logprobs.model_dump() if choice_logprobs else None, "finish_reason": (finish_reason["type"] if finish_reason else ""), @@ -1170,9 +1170,9 @@ index=idx, message=ChatMessage( role="assistant", - content=text if tool_calls is None else None, + content=text if text else None, tool_calls=tool_calls, - reasoning_content=reasoning_text, + reasoning_content=reasoning_text if reasoning_text else None, ), logprobs=choice_logprobs, finish_reason=(finish_reason["type"] if finish_reason else ""), @@ -1317,9 +1317,11 @@ tokenizer_manager.server_args.reasoning_parser and request.separate_reasoning ): - delta = DeltaMessage(role="assistant", reasoning_content="") + delta = DeltaMessage( + role="assistant", reasoning_content=None + ) else: - delta = DeltaMessage(role="assistant", content="") + delta = DeltaMessage(role="assistant", content=None) choice_data = ChatCompletionResponseStreamChoice( index=index, delta=delta, @@ -1362,7 +1364,11 @@ if reasoning_text: choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(reasoning_content=reasoning_text), + delta=DeltaMessage( + reasoning_content=( + reasoning_text if reasoning_text else None + ) + ), finish_reason=( None if finish_reason_type @@ -1396,7 +1402,9 @@ if normal_text: choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=normal_text), + delta=DeltaMessage( + content=normal_text if normal_text else None + ), finish_reason=( None if finish_reason_type @@ -1468,7 +1476,7 @@ # No tool calls => just treat this as normal text choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=delta), + delta=DeltaMessage(content=delta if delta else None), finish_reason=( None if finish_reason_type and len(finish_reason_type) == 0 diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 
23e028729..e9adf617f 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -257,7 +257,7 @@ class TestOpenAIServer(unittest.TestCase): ret_num_top_logprobs == logprobs ), f"{ret_num_top_logprobs} vs {logprobs}" - assert isinstance(data.content, str) + assert isinstance(data.content, str) or response.choices[0].finish_reason assert response.id assert response.created
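
To illustrate the behavior this diff introduces, here is a minimal usage sketch (not part of the change itself) of the new parse_non_stream contract: detectors now return a StreamingParseResult, so the reasoning text that precedes a <tool_call> block comes back as normal_text instead of being dropped. The Tool/Function field names are assumed from the notebook's convert_dict_to_tool helper, and the completion string is hand-written, so treat this as an unverified sketch rather than a test from the PR.

# Usage sketch (assumes sglang with this patch applied; the sample
# completion string and tool schema below are invented for illustration).
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Function, Tool

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_current_weather",
            description="Get the current weather in a given location",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        ),
    )
]

# A hand-written qwen25-style completion: reasoning text, then a tool call.
generated_text = (
    "I should check the weather with the available tool.\n"
    '<tool_call>{"name": "get_current_weather", "arguments": {"city": "Boston"}}</tool_call>'
)

parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
normal_text, calls = parser.parse_non_stream(generated_text)

# With this patch, normal_text is the reasoning prefix rather than being lost,
# and each parsed call carries the function name plus its JSON parameters.
print("normal text:", normal_text)
for call in calls:
    print("call:", call.name, call.parameters)

This same split is what the adapter change above relies on: the leftover normal_text populates content (or None when empty) while the parsed calls populate tool_calls.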