feat: append more comprehensive fields in messages instead of merely role and content (#5996)
This commit is contained in:
@@ -38,7 +38,9 @@
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
" import nest_asyncio\n",
|
||||
"\n",
|
||||
" nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n",
|
||||
@@ -164,7 +166,7 @@
|
||||
"response_non_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.1,\n",
|
||||
" temperature=0,\n",
|
||||
" top_p=0.95,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" stream=False, # Non-streaming\n",
|
||||
@@ -219,7 +221,7 @@
|
||||
"response_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.1,\n",
|
||||
" temperature=0,\n",
|
||||
" top_p=0.95,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" stream=True, # Enable streaming\n",
|
||||
@@ -309,22 +311,23 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"call_data = json.loads(full_arguments)\n",
|
||||
"\n",
|
||||
"messages.append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"\",\n",
|
||||
" \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"messages.append(response_non_stream.choices[0].message)\n",
|
||||
"\n",
|
||||
"# Call the corresponding tool function\n",
|
||||
"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
|
||||
"tool_call = messages[-1].tool_calls[0]\n",
|
||||
"tool_name = tool_call.function.name\n",
|
||||
"tool_to_call = available_tools[tool_name]\n",
|
||||
"result = tool_to_call(**call_data)\n",
|
||||
"result = tool_to_call(**(json.loads(tool_call.function.arguments)))\n",
|
||||
"print_highlight(f\"Function call result: {result}\")\n",
|
||||
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
|
||||
"# messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
|
||||
"messages.append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"tool\",\n",
|
||||
" \"tool_call_id\": tool_call.id,\n",
|
||||
" \"content\": str(result),\n",
|
||||
" \"name\": tool_name,\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"Updated message history: {messages}\")"
|
||||
]
|
||||
@@ -345,7 +348,7 @@
|
||||
"final_response = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.1,\n",
|
||||
" temperature=0,\n",
|
||||
" top_p=0.95,\n",
|
||||
" stream=False,\n",
|
||||
" tools=tools,\n",
|
||||
@@ -391,7 +394,7 @@
|
||||
" \"sampling_params\": {\n",
|
||||
" \"skip_special_tokens\": False,\n",
|
||||
" \"max_new_tokens\": 1024,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"temperature\": 0,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
@@ -452,7 +455,7 @@
|
||||
"\n",
|
||||
"sampling_params = {\n",
|
||||
" \"max_new_tokens\": 1024,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"temperature\": 0,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" \"skip_special_tokens\": False,\n",
|
||||
"}\n",
|
||||
@@ -540,14 +543,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1\" # llama-3.2-1b-instruct\n",
|
||||
@@ -624,8 +619,8 @@
|
||||
"response_non_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" temperature=0,\n",
|
||||
" top_p=0.9,\n",
|
||||
" stream=False, # Non-streaming\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
@@ -635,8 +630,8 @@
|
||||
"response_stream = client.chat.completions.create(\n",
|
||||
" model=model_name,\n",
|
||||
" messages=messages,\n",
|
||||
" temperature=0.8,\n",
|
||||
" top_p=0.8,\n",
|
||||
" temperature=0,\n",
|
||||
" top_p=0.9,\n",
|
||||
" stream=True,\n",
|
||||
" tools=tools,\n",
|
||||
")\n",
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
"""Conversion between OpenAI APIs and native SRT APIs"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -970,17 +971,19 @@ def v1_chat_generate_request(
|
||||
for message in request.messages:
|
||||
if message.content is None:
|
||||
message.content = ""
|
||||
if isinstance(message.content, str):
|
||||
openai_compatible_messages.append(
|
||||
{"role": message.role, "content": message.content}
|
||||
)
|
||||
msg_dict = message.dict()
|
||||
if isinstance(msg_dict.get("content"), list):
|
||||
for chunk in msg_dict["content"]:
|
||||
if isinstance(chunk, dict) and chunk.get("type") == "text":
|
||||
new_msg = msg_dict.copy()
|
||||
new_msg["content"] = chunk["text"]
|
||||
new_msg = {
|
||||
k: v for k, v in new_msg.items() if v is not None
|
||||
}
|
||||
openai_compatible_messages.append(new_msg)
|
||||
else:
|
||||
content_list = message.dict()["content"]
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
openai_compatible_messages.append(
|
||||
{"role": message.role, "content": content["text"]}
|
||||
)
|
||||
msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
|
||||
openai_compatible_messages.append(msg_dict)
|
||||
if (
|
||||
openai_compatible_messages
|
||||
and openai_compatible_messages[-1]["role"] == "assistant"
|
||||
@@ -1290,7 +1293,8 @@ def v1_chat_generate_response(
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info.tool_index),
|
||||
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
||||
index=call_info.tool_index,
|
||||
function=FunctionResponse(
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
@@ -1406,6 +1410,7 @@ async def v1_chat_completions(
|
||||
reasoning_parser_dict = {}
|
||||
|
||||
async def generate_stream_resp():
|
||||
tool_call_first = True
|
||||
is_firsts = {}
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
@@ -1572,7 +1577,6 @@ async def v1_chat_completions(
|
||||
# 2) if we found calls, we output them as separate chunk(s)
|
||||
for call_item in calls:
|
||||
# transform call_item -> FunctionResponse + ToolCall
|
||||
|
||||
if finish_reason_type == "stop":
|
||||
latest_delta_len = 0
|
||||
if isinstance(call_item.parameters, str):
|
||||
@@ -1595,15 +1599,19 @@ async def v1_chat_completions(
|
||||
call_item.parameters = remaining_call
|
||||
|
||||
finish_reason_type = "tool_calls"
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=str(call_item.tool_index),
|
||||
id=(
|
||||
f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
|
||||
if tool_call_first
|
||||
else None
|
||||
),
|
||||
index=call_item.tool_index,
|
||||
function=FunctionResponse(
|
||||
name=call_item.name,
|
||||
arguments=call_item.parameters,
|
||||
),
|
||||
)
|
||||
tool_call_first = False
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(tool_calls=[tool_call]),
|
||||
|
||||
@@ -250,9 +250,29 @@ ChatCompletionMessageContentPart = Union[
|
||||
]
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Tool call response."""
|
||||
|
||||
id: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionResponse
|
||||
|
||||
|
||||
class ChatCompletionMessageGenericParam(BaseModel):
|
||||
role: Literal["system", "assistant", "tool"]
|
||||
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
|
||||
tool_call_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionMessageUserParam(BaseModel):
|
||||
@@ -378,22 +398,6 @@ class ChatCompletionRequest(BaseModel):
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Tool call response."""
|
||||
|
||||
id: str
|
||||
index: Optional[int] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionResponse
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user