Constraint Decoding: Tool call with text (#4067)
This commit is contained in:
@@ -41,7 +41,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"server_process, port = launch_server_cmd(\n",
|
"server_process, port = launch_server_cmd(\n",
|
||||||
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\" # llama3\n",
|
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n",
|
||||||
")\n",
|
")\n",
|
||||||
"wait_for_server(f\"http://localhost:{port}\")"
|
"wait_for_server(f\"http://localhost:{port}\")"
|
||||||
]
|
]
|
||||||
@@ -55,7 +55,7 @@
|
|||||||
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
||||||
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||||
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
|
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -121,7 +121,7 @@
|
|||||||
" return [\n",
|
" return [\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"role\": \"user\",\n",
|
" \"role\": \"user\",\n",
|
||||||
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
|
" \"content\": \"What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.\",\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" ]\n",
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -164,63 +164,28 @@
|
|||||||
"response_non_stream = client.chat.completions.create(\n",
|
"response_non_stream = client.chat.completions.create(\n",
|
||||||
" model=model_name,\n",
|
" model=model_name,\n",
|
||||||
" messages=messages,\n",
|
" messages=messages,\n",
|
||||||
" temperature=0.8,\n",
|
" temperature=0.1,\n",
|
||||||
" top_p=0.8,\n",
|
" top_p=0.95,\n",
|
||||||
|
" max_tokens=1024,\n",
|
||||||
" stream=False, # Non-streaming\n",
|
" stream=False, # Non-streaming\n",
|
||||||
" tools=tools,\n",
|
" tools=tools,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"print_highlight(\"Non-stream response:\")\n",
|
"print_highlight(\"Non-stream response:\")\n",
|
||||||
"print(response_non_stream)"
|
"print(response_non_stream)\n",
|
||||||
|
"print_highlight(\"==== content ====\")\n",
|
||||||
|
"print(response_non_stream.choices[0].message.content)\n",
|
||||||
|
"print_highlight(\"==== tool_calls ====\")\n",
|
||||||
|
"print(response_non_stream.choices[0].message.tool_calls)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Streaming Request"
|
"#### Handle Tools\n",
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Streaming mode test\n",
|
|
||||||
"print_highlight(\"Streaming response:\")\n",
|
|
||||||
"response_stream = client.chat.completions.create(\n",
|
|
||||||
" model=model_name,\n",
|
|
||||||
" messages=messages,\n",
|
|
||||||
" temperature=0.8,\n",
|
|
||||||
" top_p=0.8,\n",
|
|
||||||
" stream=True, # Enable streaming\n",
|
|
||||||
" tools=tools,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"chunks = []\n",
|
|
||||||
"for chunk in response_stream:\n",
|
|
||||||
" chunks.append(chunk)\n",
|
|
||||||
" if chunk.choices[0].delta.tool_calls:\n",
|
|
||||||
" print(chunk.choices[0].delta.tool_calls[0])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"### Handle Tool Calls\n",
|
|
||||||
"\n",
|
|
||||||
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Non-Streaming Request**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@@ -240,7 +205,50 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"**Streaming Request**"
|
"### Streaming Request"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Streaming mode test\n",
|
||||||
|
"print_highlight(\"Streaming response:\")\n",
|
||||||
|
"response_stream = client.chat.completions.create(\n",
|
||||||
|
" model=model_name,\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" top_p=0.95,\n",
|
||||||
|
" max_tokens=1024,\n",
|
||||||
|
" stream=True, # Enable streaming\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"texts = \"\"\n",
|
||||||
|
"tool_calls = []\n",
|
||||||
|
"name = \"\"\n",
|
||||||
|
"arguments = \"\"\n",
|
||||||
|
"for chunk in response_stream:\n",
|
||||||
|
" if chunk.choices[0].delta.content:\n",
|
||||||
|
" texts += chunk.choices[0].delta.content\n",
|
||||||
|
" if chunk.choices[0].delta.tool_calls:\n",
|
||||||
|
" tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
|
||||||
|
"print_highlight(\"==== Text ====\")\n",
|
||||||
|
"print(texts)\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(\"==== Tool Call ====\")\n",
|
||||||
|
"for tool_call in tool_calls:\n",
|
||||||
|
" print(tool_call)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Handle Tools\n",
|
||||||
|
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -251,21 +259,16 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Parse and combine function call arguments\n",
|
"# Parse and combine function call arguments\n",
|
||||||
"arguments = []\n",
|
"arguments = []\n",
|
||||||
"for chunk in chunks:\n",
|
"for tool_call in tool_calls:\n",
|
||||||
" choice = chunk.choices[0]\n",
|
" if tool_call.function.name:\n",
|
||||||
" delta = choice.delta\n",
|
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
||||||
" if delta.tool_calls:\n",
|
|
||||||
" tool_call = delta.tool_calls[0]\n",
|
|
||||||
" if tool_call.function.name:\n",
|
|
||||||
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" if tool_call.function.arguments:\n",
|
" if tool_call.function.arguments:\n",
|
||||||
" arguments.append(tool_call.function.arguments)\n",
|
" arguments.append(tool_call.function.arguments)\n",
|
||||||
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Combine all fragments into a single JSON string\n",
|
"# Combine all fragments into a single JSON string\n",
|
||||||
"full_arguments = \"\".join(arguments)\n",
|
"full_arguments = \"\".join(arguments)\n",
|
||||||
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
|
"print_highlight(f\"streamed function call arguments: {full_arguments}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -342,13 +345,16 @@
|
|||||||
"final_response = client.chat.completions.create(\n",
|
"final_response = client.chat.completions.create(\n",
|
||||||
" model=model_name,\n",
|
" model=model_name,\n",
|
||||||
" messages=messages,\n",
|
" messages=messages,\n",
|
||||||
" temperature=0.8,\n",
|
" temperature=0.1,\n",
|
||||||
" top_p=0.8,\n",
|
" top_p=0.95,\n",
|
||||||
" stream=False,\n",
|
" stream=False,\n",
|
||||||
" tools=tools,\n",
|
" tools=tools,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"print_highlight(\"Non-stream response:\")\n",
|
"print_highlight(\"Non-stream response:\")\n",
|
||||||
"print(final_response)"
|
"print(final_response)\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(\"==== Text ====\")\n",
|
||||||
|
"print(final_response.choices[0].message.content)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -368,7 +374,7 @@
|
|||||||
"import requests\n",
|
"import requests\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# generate an answer\n",
|
"# generate an answer\n",
|
||||||
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"messages = get_messages()\n",
|
"messages = get_messages()\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -380,8 +386,17 @@
|
|||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"gen_url = f\"http://localhost:{port}/generate\"\n",
|
"gen_url = f\"http://localhost:{port}/generate\"\n",
|
||||||
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
|
"gen_data = {\n",
|
||||||
|
" \"text\": input,\n",
|
||||||
|
" \"sampling_params\": {\n",
|
||||||
|
" \"skip_special_tokens\": False,\n",
|
||||||
|
" \"max_new_tokens\": 1024,\n",
|
||||||
|
" \"temperature\": 0.1,\n",
|
||||||
|
" \"top_p\": 0.95,\n",
|
||||||
|
" },\n",
|
||||||
|
"}\n",
|
||||||
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
|
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
|
||||||
|
"print_highlight(\"==== Reponse ====\")\n",
|
||||||
"print(gen_response)\n",
|
"print(gen_response)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# parse the response\n",
|
"# parse the response\n",
|
||||||
@@ -389,12 +404,16 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"function_call_input = {\n",
|
"function_call_input = {\n",
|
||||||
" \"text\": gen_response,\n",
|
" \"text\": gen_response,\n",
|
||||||
" \"tool_call_parser\": \"llama3\",\n",
|
" \"tool_call_parser\": \"qwen25\",\n",
|
||||||
" \"tools\": tools,\n",
|
" \"tools\": tools,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
|
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
|
||||||
"function_call_response_json = function_call_response.json()\n",
|
"function_call_response_json = function_call_response.json()\n",
|
||||||
|
"\n",
|
||||||
|
"print_highlight(\"==== Text ====\")\n",
|
||||||
|
"print(function_call_response_json[\"normal_text\"])\n",
|
||||||
|
"print_highlight(\"==== Calls ====\")\n",
|
||||||
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
|
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
|
||||||
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
|
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
|
||||||
]
|
]
|
||||||
@@ -425,15 +444,15 @@
|
|||||||
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
||||||
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||||
"tokenizer = llm.tokenizer_manager.tokenizer\n",
|
"tokenizer = llm.tokenizer_manager.tokenizer\n",
|
||||||
"input_ids = tokenizer.apply_chat_template(\n",
|
"input_ids = tokenizer.apply_chat_template(\n",
|
||||||
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
|
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"sampling_params = {\n",
|
"sampling_params = {\n",
|
||||||
" \"max_new_tokens\": 128,\n",
|
" \"max_new_tokens\": 1024,\n",
|
||||||
" \"temperature\": 0.3,\n",
|
" \"temperature\": 0.1,\n",
|
||||||
" \"top_p\": 0.95,\n",
|
" \"top_p\": 0.95,\n",
|
||||||
" \"skip_special_tokens\": False,\n",
|
" \"skip_special_tokens\": False,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
@@ -461,10 +480,10 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
|
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
|
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n",
|
||||||
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
|
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"\\n=== Parsing Result ===\")\n",
|
"print(\"=== Parsing Result ===\")\n",
|
||||||
"print(\"Normal text portion:\", normal_text)\n",
|
"print(\"Normal text portion:\", normal_text)\n",
|
||||||
"print(\"Function call portion:\")\n",
|
"print(\"Function call portion:\")\n",
|
||||||
"for call in calls:\n",
|
"for call in calls:\n",
|
||||||
@@ -521,5 +540,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,13 +128,15 @@ class BaseFormatDetector:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(
|
||||||
|
self, text: str, tools: List[Function]
|
||||||
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||||
Note that leftover_text here represents "content that this parser will not consume further".
|
Note that leftover_text here represents "content that this parser will not consume further".
|
||||||
"""
|
"""
|
||||||
action = json.loads(text)
|
action = json.loads(text)
|
||||||
return self.parse_base_json(action, tools)
|
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||||
|
|
||||||
def parse_streaming_increment(
|
def parse_streaming_increment(
|
||||||
self, new_text: str, tools: List[Function]
|
self, new_text: str, tools: List[Function]
|
||||||
@@ -322,7 +324,9 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||||
return self.bot_token in text
|
return self.bot_token in text
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(
|
||||||
|
self, text: str, tools: List[Function]
|
||||||
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
|
|
||||||
@@ -330,15 +334,17 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
:param tools: List of available tools.
|
:param tools: List of available tools.
|
||||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||||
"""
|
"""
|
||||||
if "<tool_call>" not in text:
|
idx = text.find(self.bot_token)
|
||||||
return []
|
normal_text = text[:idx].strip() if idx != -1 else text
|
||||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
if self.bot_token not in text:
|
||||||
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
|
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
|
||||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||||
calls = []
|
calls = []
|
||||||
for match_result in match_result_list:
|
for match_result in match_result_list:
|
||||||
match_result = json.loads(match_result)
|
match_result = json.loads(match_result)
|
||||||
calls.extend(self.parse_base_json(match_result, tools))
|
calls.extend(self.parse_base_json(match_result, tools))
|
||||||
return calls
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
|
||||||
class MistralDetector(BaseFormatDetector):
|
class MistralDetector(BaseFormatDetector):
|
||||||
@@ -374,7 +380,9 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(
|
||||||
|
self, text: str, tools: List[Function]
|
||||||
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
|
|
||||||
@@ -382,6 +390,8 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
:param tools: List of available tools.
|
:param tools: List of available tools.
|
||||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||||
"""
|
"""
|
||||||
|
idx = text.find(self.bot_token)
|
||||||
|
normal_text = text[:idx].strip() if idx != -1 else text
|
||||||
text = self._clean_text(text)
|
text = self._clean_text(text)
|
||||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||||
@@ -391,7 +401,7 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
function_call_arr = json.loads(raw_tool_call)
|
function_call_arr = json.loads(raw_tool_call)
|
||||||
for match_result in function_call_arr:
|
for match_result in function_call_arr:
|
||||||
calls.extend(self.parse_base_json(match_result, tools))
|
calls.extend(self.parse_base_json(match_result, tools))
|
||||||
return calls
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
|
||||||
class Llama32Detector(BaseFormatDetector):
|
class Llama32Detector(BaseFormatDetector):
|
||||||
@@ -414,7 +424,7 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||||
"""Parse function calls from text, handling multiple JSON objects."""
|
"""Parse function calls from text, handling multiple JSON objects."""
|
||||||
if "<|python_tag|>" not in text and not text.startswith("{"):
|
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||||
return []
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
|
|
||||||
if "<|python_tag|>" in text:
|
if "<|python_tag|>" in text:
|
||||||
_, action_text = text.split("<|python_tag|>")
|
_, action_text = text.split("<|python_tag|>")
|
||||||
@@ -423,7 +433,6 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
|
|
||||||
# Split by semicolon and process each part
|
# Split by semicolon and process each part
|
||||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||||
|
|
||||||
all_actions = []
|
all_actions = []
|
||||||
for part in json_parts:
|
for part in json_parts:
|
||||||
try:
|
try:
|
||||||
@@ -434,12 +443,11 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
logger.warning(f"Failed to parse JSON part: {part}")
|
logger.warning(f"Failed to parse JSON part: {part}")
|
||||||
logger.warning(f"JSON parse error: {str(e)}")
|
logger.warning(f"JSON parse error: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
calls = []
|
||||||
# Only process if we found valid JSON objects
|
# Only process if we found valid JSON objects
|
||||||
if all_actions:
|
if all_actions:
|
||||||
return self.parse_base_json(all_actions, tools)
|
calls = self.parse_base_json(all_actions, tools)
|
||||||
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class MultiFormatParser:
|
class MultiFormatParser:
|
||||||
@@ -449,7 +457,9 @@ class MultiFormatParser:
|
|||||||
"""
|
"""
|
||||||
self.detectors = detectors
|
self.detectors = detectors
|
||||||
|
|
||||||
def parse_once(self, text: str, tools: List[Function]):
|
def parse_once(
|
||||||
|
self, text: str, tools: List[Function]
|
||||||
|
) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||||
Return: (final_text, all_calls)
|
Return: (final_text, all_calls)
|
||||||
@@ -459,15 +469,19 @@ class MultiFormatParser:
|
|||||||
final_calls = []
|
final_calls = []
|
||||||
final_normal_text = text
|
final_normal_text = text
|
||||||
for detector in self.detectors:
|
for detector in self.detectors:
|
||||||
tool_call_list = detector.detect_and_parse(text, tools)
|
parsed_result = detector.detect_and_parse(text, tools)
|
||||||
|
tool_call_list = parsed_result.calls
|
||||||
if len(tool_call_list) > 0: # parsed successfully
|
if len(tool_call_list) > 0: # parsed successfully
|
||||||
final_calls = tool_call_list
|
final_calls = tool_call_list
|
||||||
|
final_normal_text = parsed_result.normal_text
|
||||||
break
|
break
|
||||||
|
|
||||||
# leftover_text is the normal text not consumed by any Detector
|
# leftover_text is the normal text not consumed by any Detector
|
||||||
return final_normal_text, final_calls
|
return final_normal_text, final_calls
|
||||||
|
|
||||||
def parse_streaming_increment(self, new_text: str, tools: List[Function]):
|
def parse_streaming_increment(
|
||||||
|
self, new_text: str, tools: List[Function]
|
||||||
|
) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||||
and merge their produced normal_text/calls to return.
|
and merge their produced normal_text/calls to return.
|
||||||
@@ -532,7 +546,7 @@ class FunctionCallParser:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def parse_non_stream(self, full_text: str):
|
def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
Non-streaming call: one-time parsing
|
Non-streaming call: one-time parsing
|
||||||
"""
|
"""
|
||||||
@@ -541,7 +555,7 @@ class FunctionCallParser:
|
|||||||
)
|
)
|
||||||
return full_normal_text, calls
|
return full_normal_text, calls
|
||||||
|
|
||||||
def parse_stream_chunk(self, chunk_text: str):
|
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
Streaming call: incremental parsing
|
Streaming call: incremental parsing
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1130,7 +1130,7 @@ def v1_chat_generate_response(
|
|||||||
finish_reason["type"] = "tool_calls"
|
finish_reason["type"] = "tool_calls"
|
||||||
finish_reason["matched"] = None
|
finish_reason["matched"] = None
|
||||||
try:
|
try:
|
||||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
text, call_info_list = parser.parse_non_stream(text)
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=str(call_info.tool_index),
|
id=str(call_info.tool_index),
|
||||||
@@ -1153,9 +1153,9 @@ def v1_chat_generate_response(
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": text if tool_calls is None else None,
|
"content": text if text else None,
|
||||||
"tool_calls": tool_calls,
|
"tool_calls": tool_calls,
|
||||||
"reasoning_content": reasoning_text,
|
"reasoning_content": reasoning_text if reasoning_text else None,
|
||||||
},
|
},
|
||||||
"logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
|
"logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
|
||||||
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
||||||
@@ -1170,9 +1170,9 @@ def v1_chat_generate_response(
|
|||||||
index=idx,
|
index=idx,
|
||||||
message=ChatMessage(
|
message=ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=text if tool_calls is None else None,
|
content=text if text else None,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
reasoning_content=reasoning_text,
|
reasoning_content=reasoning_text if reasoning_text else None,
|
||||||
),
|
),
|
||||||
logprobs=choice_logprobs,
|
logprobs=choice_logprobs,
|
||||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||||
@@ -1317,9 +1317,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
tokenizer_manager.server_args.reasoning_parser
|
tokenizer_manager.server_args.reasoning_parser
|
||||||
and request.separate_reasoning
|
and request.separate_reasoning
|
||||||
):
|
):
|
||||||
delta = DeltaMessage(role="assistant", reasoning_content="")
|
delta = DeltaMessage(
|
||||||
|
role="assistant", reasoning_content=None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
delta = DeltaMessage(role="assistant", content="")
|
delta = DeltaMessage(role="assistant", content=None)
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
delta=delta,
|
delta=delta,
|
||||||
@@ -1362,7 +1364,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
if reasoning_text:
|
if reasoning_text:
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
delta=DeltaMessage(
|
||||||
|
reasoning_content=(
|
||||||
|
reasoning_text if reasoning_text else None
|
||||||
|
)
|
||||||
|
),
|
||||||
finish_reason=(
|
finish_reason=(
|
||||||
None
|
None
|
||||||
if finish_reason_type
|
if finish_reason_type
|
||||||
@@ -1396,7 +1402,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
if normal_text:
|
if normal_text:
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
delta=DeltaMessage(content=normal_text),
|
delta=DeltaMessage(
|
||||||
|
content=normal_text if normal_text else None
|
||||||
|
),
|
||||||
finish_reason=(
|
finish_reason=(
|
||||||
None
|
None
|
||||||
if finish_reason_type
|
if finish_reason_type
|
||||||
@@ -1468,7 +1476,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
# No tool calls => just treat this as normal text
|
# No tool calls => just treat this as normal text
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
delta=DeltaMessage(content=delta),
|
delta=DeltaMessage(content=delta if delta else None),
|
||||||
finish_reason=(
|
finish_reason=(
|
||||||
None
|
None
|
||||||
if finish_reason_type and len(finish_reason_type) == 0
|
if finish_reason_type and len(finish_reason_type) == 0
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
ret_num_top_logprobs == logprobs
|
ret_num_top_logprobs == logprobs
|
||||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||||
|
|
||||||
assert isinstance(data.content, str)
|
assert isinstance(data.content, str) or response.choices[0].finish_reason
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user