Constraint Decoding: Tool call with text (#4067)
This commit is contained in:
@@ -41,7 +41,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\" # llama3\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n",
|
||||
")\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\")"
|
||||
]
|
||||
@@ -55,7 +55,7 @@
|
||||
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
|
||||
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
|
||||
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
|
||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
|
||||
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -121,7 +121,7 @@
|
||||
" return [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
|
||||
" \"content\": \"What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.\",\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
@@ -164,63 +164,28 @@
|
||||
"response_non_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" temperature=0.1,\n",
|
||||
" top_p=0.95,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" stream=False, # Non-streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"print_highlight(\"Non-stream response:\")\n",
|
||||
"print(response_non_stream)"
|
||||
"print(response_non_stream)\n",
|
||||
"print_highlight(\"==== content ====\")\n",
|
||||
"print(response_non_stream.choices[0].message.content)\n",
|
||||
"print_highlight(\"==== tool_calls ====\")\n",
|
||||
"print(response_non_stream.choices[0].message.tool_calls)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming Request"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming mode test\n",
|
||||
"print_highlight(\"Streaming response:\")\n",
|
||||
"response_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" stream=True, # Enable streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chunks = []\n",
|
||||
"for chunk in response_stream:\n",
|
||||
" chunks.append(chunk)\n",
|
||||
" if chunk.choices[0].delta.tool_calls:\n",
|
||||
" print(chunk.choices[0].delta.tool_calls[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"### Handle Tool Calls\n",
|
||||
"\n",
|
||||
"#### Handle Tools\n",
|
||||
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Non-Streaming Request**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -240,7 +205,50 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Streaming Request**"
|
||||
"### Streaming Request"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming mode test\n",
|
||||
"print_highlight(\"Streaming response:\")\n",
|
||||
"response_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.1,\n",
|
||||
" top_p=0.95,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" stream=True, # Enable streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"texts = \"\"\n",
|
||||
"tool_calls = []\n",
|
||||
"name = \"\"\n",
|
||||
"arguments = \"\"\n",
|
||||
"for chunk in response_stream:\n",
|
||||
" if chunk.choices[0].delta.content:\n",
|
||||
" texts += chunk.choices[0].delta.content\n",
|
||||
" if chunk.choices[0].delta.tool_calls:\n",
|
||||
" tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
|
||||
"print_highlight(\"==== Text ====\")\n",
|
||||
"print(texts)\n",
|
||||
"\n",
|
||||
"print_highlight(\"==== Tool Call ====\")\n",
|
||||
"for tool_call in tool_calls:\n",
|
||||
" print(tool_call)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Handle Tools\n",
|
||||
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -251,21 +259,16 @@
|
||||
"source": [
|
||||
"# Parse and combine function call arguments\n",
|
||||
"arguments = []\n",
|
||||
"for chunk in chunks:\n",
|
||||
" choice = chunk.choices[0]\n",
|
||||
" delta = choice.delta\n",
|
||||
" if delta.tool_calls:\n",
|
||||
" tool_call = delta.tool_calls[0]\n",
|
||||
" if tool_call.function.name:\n",
|
||||
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
||||
"for tool_call in tool_calls:\n",
|
||||
" if tool_call.function.name:\n",
|
||||
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
|
||||
"\n",
|
||||
" if tool_call.function.arguments:\n",
|
||||
" arguments.append(tool_call.function.arguments)\n",
|
||||
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
|
||||
" if tool_call.function.arguments:\n",
|
||||
" arguments.append(tool_call.function.arguments)\n",
|
||||
"\n",
|
||||
"# Combine all fragments into a single JSON string\n",
|
||||
"full_arguments = \"\".join(arguments)\n",
|
||||
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
|
||||
"print_highlight(f\"streamed function call arguments: {full_arguments}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -342,13 +345,16 @@
|
||||
"final_response = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" temperature=0.1,\n",
|
||||
" top_p=0.95,\n",
|
||||
" stream=False,\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
"print_highlight(\"Non-stream response:\")\n",
|
||||
"print(final_response)"
|
||||
"print(final_response)\n",
|
||||
"\n",
|
||||
"print_highlight(\"==== Text ====\")\n",
|
||||
"print(final_response.choices[0].message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -368,7 +374,7 @@
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"# generate an answer\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||
"\n",
|
||||
"messages = get_messages()\n",
|
||||
"\n",
|
||||
@@ -380,8 +386,17 @@
|
||||
")\n",
|
||||
"\n",
|
||||
"gen_url = f\"http://localhost:{port}/generate\"\n",
|
||||
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
|
||||
"gen_data = {\n",
|
||||
" \"text\": input,\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"skip_special_tokens\": False,\n",
|
||||
" \"max_new_tokens\": 1024,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
|
||||
"print_highlight(\"==== Reponse ====\")\n",
|
||||
"print(gen_response)\n",
|
||||
"\n",
|
||||
"# parse the response\n",
|
||||
@@ -389,12 +404,16 @@
|
||||
"\n",
|
||||
"function_call_input = {\n",
|
||||
" \"text\": gen_response,\n",
|
||||
" \"tool_call_parser\": \"llama3\",\n",
|
||||
" \"tool_call_parser\": \"qwen25\",\n",
|
||||
" \"tools\": tools,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
|
||||
"function_call_response_json = function_call_response.json()\n",
|
||||
"\n",
|
||||
"print_highlight(\"==== Text ====\")\n",
|
||||
"print(function_call_response_json[\"normal_text\"])\n",
|
||||
"print_highlight(\"==== Calls ====\")\n",
|
||||
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
|
||||
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
|
||||
]
|
||||
@@ -425,15 +444,15 @@
|
||||
"from sglang.srt.function_call_parser import FunctionCallParser\n",
|
||||
"from sglang.srt.managers.io_struct import Tool, Function\n",
|
||||
"\n",
|
||||
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
|
||||
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
|
||||
"tokenizer = llm.tokenizer_manager.tokenizer\n",
|
||||
"input_ids = tokenizer.apply_chat_template(\n",
|
||||
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"sampling_params = {\n",
|
||||
" \"max_new_tokens\": 128,\n",
|
||||
" \"temperature\": 0.3,\n",
|
||||
" \"max_new_tokens\": 1024,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" \"skip_special_tokens\": False,\n",
|
||||
"}\n",
|
||||
@@ -461,10 +480,10 @@
|
||||
"\n",
|
||||
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
|
||||
"\n",
|
||||
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
|
||||
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n",
|
||||
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
|
||||
"\n",
|
||||
"print(\"\\n=== Parsing Result ===\")\n",
|
||||
"print(\"=== Parsing Result ===\")\n",
|
||||
"print(\"Normal text portion:\", normal_text)\n",
|
||||
"print(\"Function call portion:\")\n",
|
||||
"for call in calls:\n",
|
||||
@@ -521,5 +540,5 @@
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user