Support v1/responses and use harmony in serving_chat (#8837)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@@ -29,6 +29,7 @@ runtime_common = [
    "modelscope",
    "msgspec",
    "ninja",
    "openai-harmony==0.0.3",
    "orjson",
    "outlines==0.1.11",
    "packaging",
@@ -96,7 +97,7 @@ srt_cpu = ["sglang[runtime_common]", "einops"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu = ["sglang[runtime_common]"]

openai = ["openai>=1.0", "tiktoken"]
openai = ["openai>=1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]

python/sglang/srt/entrypoints/context.py (new file, 244 lines)
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union

logger = logging.getLogger(__name__)

try:
    from mcp import ClientSession
except ImportError:
    logger.warning("Ignoring mcp import error")

from openai_harmony import Author, Message, Role, StreamState, TextContent

from sglang.srt.entrypoints.harmony_utils import (
    get_encoding,
    get_streamable_parser_for_assistant,
    render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool


class ConversationContext(ABC):

    @abstractmethod
    def append_output(self, output) -> None:
        pass

    @abstractmethod
    async def call_tool(self) -> list[Message]:
        pass

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        pass

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        pass


class SimpleContext(ConversationContext):

    def __init__(self):
        self.last_output = None

    def append_output(self, output) -> None:
        self.last_output = output

    def need_builtin_tool_call(self) -> bool:
        return False

    async def call_tool(self) -> list[Message]:
        raise NotImplementedError("Should not be called.")

    def render_for_completion(self) -> list[int]:
        raise NotImplementedError("Should not be called.")


class HarmonyContext(ConversationContext):

    def __init__(
        self,
        messages: list,
        tool_sessions: dict[str, Union["ClientSession", Tool]],
    ):
        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
        # when demo.
        self._messages = messages
        self.tool_sessions = tool_sessions

        self.parser = get_streamable_parser_for_assistant()
        self.num_init_messages = len(messages)
        # TODO
        self.num_prompt_tokens = 0
        self.num_cached_tokens = 0
        self.num_output_tokens = 0
        self.num_reasoning_tokens = 0

    def append_output(self, output) -> None:
        if isinstance(output, dict) and "output_ids" in output:
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky, find the first occurrence of token 200006 and cut from there
            try:
                start_index = output_token_ids.index(200006)
                output_token_ids = output_token_ids[start_index:]
            except ValueError:
                pass

            for token_id in output_token_ids:
                self.parser.process(token_id)
            output_msgs = self.parser.messages

            meta_info = output["meta_info"]

            if isinstance(meta_info, dict):
                if "prompt_token_ids" in meta_info:
                    self.num_prompt_tokens = meta_info["prompt_tokens"]
                if "cached_tokens" in meta_info:
                    self.num_cached_tokens = meta_info["cached_tokens"]
                if "completion_tokens" in meta_info:
                    self.num_output_tokens += meta_info["completion_tokens"]

        else:
            output_msgs = output

        self._messages.extend(output_msgs)

    @property
    def messages(self) -> list:
        return self._messages

    def need_builtin_tool_call(self) -> bool:
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        return recipient is not None and (
            recipient.startswith("browser.") or recipient.startswith("python")
        )

    async def call_tool(self) -> list[Message]:
        if not self.messages:
            return []
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is not None:
            if recipient.startswith("browser."):
                return await self.call_search_tool(
                    self.tool_sessions["browser"], last_msg
                )
            elif recipient.startswith("python"):
                return await self.call_python_tool(
                    self.tool_sessions["python"], last_msg
                )
        raise ValueError("No tool call found")

    def render_for_completion(self) -> list[int]:
        return render_for_completion(self.messages)

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1]
        args = json.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        param = {
            "code": last_msg.content[0].text,
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text

        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name="python")

        return [
            Message(
                author=author,
                content=[content],
                channel=last_msg.channel,
                recipient=Role.ASSISTANT,
            )
        ]


class StreamingHarmonyContext(HarmonyContext):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_output = None

        self.parser = get_streamable_parser_for_assistant()
        self.encoding = get_encoding()
        self.last_tok = None

    @property
    def messages(self) -> list:
        return self.parser.messages

    def append_output(self, output) -> None:
        if isinstance(output, dict) and "output_ids" in output:
            # RequestOutput from SGLang with outputs
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky, find the first occurrence of token 200006 and cut from there
            # Find the first occurrence of token 200006 and cut from there
            try:
                start_index = output_token_ids.index(200006)
                output_token_ids = output_token_ids[start_index:]
            except ValueError:
                pass

            for token_id in output_token_ids:
                self.parser.process(token_id)

        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"
            msg = output[0]
            # Sometimes the recipient is not set for tool messages,
            # so we set it to "assistant"
            if msg.author.role == Role.TOOL and msg.recipient is None:
                msg.recipient = "assistant"
            toks = self.encoding.render(msg)
            for tok in toks:
                self.parser.process(tok)
            self.last_tok = toks[-1]

    def is_expecting_start(self) -> bool:
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

    def render_for_completion(self) -> list[int]:
        # now this list of tokens as next turn's starting tokens
        # `<|start|>assistant``,
        # we need to process them in parser.
        rendered_tokens = super().render_for_completion()

        last_n = -1
        to_process = []
        while rendered_tokens[last_n] != self.last_tok:
            to_process.append(rendered_tokens[last_n])
            last_n -= 1
        for tok in reversed(to_process):
            self.parser.process(tok)

        return rendered_tokens
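A minimal sketch (not part of the diff) of how the ConversationContext interface above is meant to be driven in an agentic loop. `generate_once` is a hypothetical stand-in for an engine call that returns a dict with "output_ids" and "meta_info", which is what HarmonyContext.append_output expects.

# Hedged sketch, assuming a HarmonyContext built with browser/python tool sessions.
async def run_until_final(context: HarmonyContext, generate_once) -> list:
    while True:
        # Render the conversation so far into prompt token ids.
        prompt_ids = context.render_for_completion()
        engine_output = await generate_once(prompt_ids)
        # Feed the new tokens back; the harmony parser turns them into messages.
        context.append_output(engine_output)
        if not context.need_builtin_tool_call():
            break
        # The last message addressed a built-in tool (browser.* / python);
        # execute it and append the tool messages to the conversation.
        tool_msgs = await context.call_tool()
        context.append_output(tool_msgs)
    return context.messages
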
python/sglang/srt/entrypoints/harmony_utils.py (new file, 370 lines)
@@ -0,0 +1,370 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable
from typing import Literal, Optional, Union

from openai.types.responses import (
    ResponseOutputItem,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_web_search import (
    ActionFind,
    ActionOpenPage,
    ActionSearch,
    ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
    Author,
    Conversation,
    DeveloperContent,
    HarmonyEncodingName,
    Message,
    ReasoningEffort,
    Role,
    StreamableParser,
    SystemContent,
    TextContent,
    ToolDescription,
    load_harmony_encoding,
)

from sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem
from sglang.srt.utils import random_uuid

REASONING_EFFORT = {
    "high": ReasoningEffort.HIGH,
    "medium": ReasoningEffort.MEDIUM,
    "low": ReasoningEffort.LOW,
}

_harmony_encoding = None


def get_encoding():
    global _harmony_encoding
    if _harmony_encoding is None:
        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def get_system_message(
    model_identity: Optional[str] = None,
    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
    start_date: Optional[str] = None,
    browser_description: Optional[str] = None,
    python_description: Optional[str] = None,
) -> Message:
    sys_msg_content = SystemContent.new()
    if model_identity is not None:
        sys_msg_content = sys_msg_content.with_model_identity(model_identity)
    if reasoning_effort is not None:
        sys_msg_content = sys_msg_content.with_reasoning_effort(
            REASONING_EFFORT[reasoning_effort]
        )
    if start_date is None:
        start_date = datetime.datetime.now().strftime("%Y-%m-%d")
    sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
    if browser_description is not None:
        sys_msg_content = sys_msg_content.with_tools(browser_description)
    if python_description is not None:
        sys_msg_content = sys_msg_content.with_tools(python_description)
    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
    return sys_msg


def get_developer_message(
    instructions: Optional[str] = None, tools: Optional[list[Tool]] = None
) -> Message:
    dev_msg_content = DeveloperContent.new()
    if instructions is not None:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools = []
        for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                # These are built-in tools that are added to the system message.
                pass
            elif tool.type == "function":
                function_tools.append(tool)
            else:
                raise ValueError(f"tool type {tool.type} not supported")
        if function_tools:
            function_tool_descriptions = [
                ToolDescription.new(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                )
                for tool in function_tools
            ]
            dev_msg_content = dev_msg_content.with_function_tools(
                function_tool_descriptions
            )
    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
    return dev_msg


def get_user_message(content: str) -> Message:
    return Message.from_role_and_content(Role.USER, content)


def parse_response_input(
    response_msg: ResponseInputOutputItem,
    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],
) -> Message:
    if not isinstance(response_msg, dict):
        response_msg = response_msg.model_dump()
    if "type" not in response_msg or response_msg["type"] == "message":
        role = response_msg["role"]
        content = response_msg["content"]
        if role == "system":
            # User is trying to set a system message. Change it to:
            # <|start|>developer<|message|># Instructions
            # {instructions}<|end|>
            role = "developer"
            text_prefix = "Instructions:\n"
        else:
            text_prefix = ""
        if isinstance(content, str):
            msg = Message.from_role_and_content(role, text_prefix + content)
        else:
            contents = [TextContent(text=text_prefix + c["text"]) for c in content]
            msg = Message.from_role_and_contents(role, contents)
    elif response_msg["type"] == "function_call_output":
        call_id = response_msg["call_id"]
        call_response: Optional[ResponseFunctionToolCall] = None
        for prev_response in reversed(prev_responses):
            if (
                isinstance(prev_response, ResponseFunctionToolCall)
                and prev_response.call_id == call_id
            ):
                call_response = prev_response
                break
        if call_response is None:
            raise ValueError(f"No call message found for {call_id}")
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{call_response.name}"),
            response_msg["output"],
        )
    elif response_msg["type"] == "reasoning":
        content = response_msg["content"]
        assert len(content) == 1
        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
    elif response_msg["type"] == "function_call":
        msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(f"functions.{response_msg['name']}")
        msg = msg.with_content_type("json")
    else:
        raise ValueError(f"Unknown input type: {response_msg['type']}")
    return msg


def parse_response_output(output: ResponseOutputItem) -> Message:
    if isinstance(output, ResponseOutputMessage):
        role = output.role
        contents = [TextContent(text=c.text) for c in output.content]
        msg = Message.from_role_and_contents(role, contents)
        return msg
    elif isinstance(output, ResponseFunctionToolCall):
        msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments)
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(output.name)
        msg = msg.with_content_type("json")
        return msg
    else:
        raise ValueError(f"Unknown output type: {type(output)}")


def parse_chat_input(chat_msg) -> Message:
    role = chat_msg.role
    content = chat_msg.content
    if isinstance(content, str):
        contents = [TextContent(text=content)]
    else:
        # TODO: Support refusal.
        contents = [TextContent(text=c.text) for c in content]
    msg = Message.from_role_and_contents(role, contents)
    return msg


def render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    token_ids = get_encoding().render_conversation_for_completion(
        conversation, Role.ASSISTANT
    )
    return token_ids


def get_stop_tokens_for_assistant_actions() -> list[int]:
    return get_encoding().stop_tokens_for_assistant_actions()


def get_streamable_parser_for_assistant() -> StreamableParser:
    return StreamableParser(get_encoding(), role=Role.ASSISTANT)


def parse_output_message(message: Message):
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    output_items = []
    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        if len(message.content) != 1:
            raise ValueError("Invalid number of contents in browser message")
        content = message.content[0]
        browser_call = json.loads(content.text)
        # TODO: translate to url properly!
        if recipient == "browser.search":
            action = ActionSearch(
                query=f"cursor:{browser_call.get('query', '')}", type="search"
            )
        elif recipient == "browser.open":
            action = ActionOpenPage(
                url=f"cursor:{browser_call.get('url', '')}", type="open_page"
            )
        elif recipient == "browser.find":
            action = ActionFind(
                pattern=browser_call["pattern"],
                url=f"cursor:{browser_call.get('url', '')}",
                type="find",
            )
        else:
            raise ValueError(f"Unknown browser action: {recipient}")
        web_search_item = ResponseFunctionWebSearch(
            id=f"ws_{random_uuid()}",
            action=action,
            status="completed",
            type="web_search_call",
        )
        output_items.append(web_search_item)
    elif message.channel == "analysis":
        for content in message.content:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                type="reasoning",
                summary=[],
                content=[
                    ResponseReasoningTextContent(
                        text=content.text, type="reasoning_text"
                    )
                ],
                status=None,
            )
            output_items.append(reasoning_item)
    elif message.channel == "commentary":
        if message.recipient.startswith("functions."):
            function_name = message.recipient.split(".")[-1]
            for content in message.content:
                random_id = random_uuid()
                response_item = ResponseFunctionToolCall(
                    arguments=content.text,
                    call_id=f"call_{random_id}",
                    type="function_call",
                    name=function_name,
                    id=f"ft_{random_id}",
                )
                output_items.append(response_item)
        elif message.recipient.startswith("python") or message.recipient.startswith(
            "browser"
        ):
            for content in message.content:
                reasoning_item = ResponseReasoningItem(
                    id=f"rs_{random_uuid()}",
                    type="reasoning",
                    summary=[],
                    content=[
                        ResponseReasoningTextContent(
                            text=content.text, type="reasoning_text"
                        )
                    ],
                    status=None,
                )
                output_items.append(reasoning_item)
        else:
            raise ValueError(f"Unknown recipient: {message.recipient}")
    elif message.channel == "final":
        contents = []
        for content in message.content:
            output_text = ResponseOutputText(
                text=content.text,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            contents.append(output_text)
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=contents,
            role=message.author.role,
            status="completed",
            type="message",
        )
        output_items.append(text_item)
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items


def parse_remaining_state(parser: StreamableParser):
    if not parser.current_content:
        return []
    if parser.current_role != Role.ASSISTANT:
        return []
    current_recipient = parser.current_recipient
    if current_recipient is not None and current_recipient.startswith("browser."):
        return []

    if parser.current_channel == "analysis":
        reasoning_item = ResponseReasoningItem(
            id=f"rs_{random_uuid()}",
            type="reasoning",
            summary=[],
            content=[
                ResponseReasoningTextContent(
                    text=parser.current_content, type="reasoning_text"
                )
            ],
            status=None,
        )
        return [reasoning_item]
    elif parser.current_channel == "final":
        output_text = ResponseOutputText(
            content=[
                ResponseReasoningTextContent(
                    text=parser.current_content, type="reasoning_text"
                )
            ],
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=[output_text],
            role="assistant",
            status="completed",
            type="message",
        )
        return [text_item]
    return []


def parse_output_into_messages(token_ids: Iterable[int]):
    parser = get_streamable_parser_for_assistant()
    for token_id in token_ids:
        parser.process(token_id)
    return parser
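A minimal sketch (not part of the diff) of composing a Harmony prompt with the helpers above, assuming openai-harmony is installed. The `output_ids` list is a placeholder for the token ids an engine would generate.

from sglang.srt.entrypoints.harmony_utils import (
    get_developer_message,
    get_system_message,
    get_user_message,
    parse_output_into_messages,
    render_for_completion,
)

sys_msg = get_system_message(reasoning_effort="low")
dev_msg = get_developer_message(instructions="Answer briefly.")
user_msg = get_user_message("What is the capital of France?")

# Token ids to feed to the engine as the prompt.
prompt_ids = render_for_completion([sys_msg, dev_msg, user_msg])

# After generation, the output token ids can be parsed back into messages.
output_ids: list[int] = []  # placeholder for engine output
parser = parse_output_into_messages(output_ids)
print([m.channel for m in parser.messages])
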
@@ -32,6 +32,7 @@ from typing import AsyncIterator, Callable, Dict, Optional
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

from contextlib import asynccontextmanager
from typing import AsyncGenerator

import numpy as np
import orjson
@@ -56,6 +57,7 @@ from sglang.srt.entrypoints.openai.protocol import (
    ErrorResponse,
    ModelCard,
    ModelList,
    ResponsesRequest,
    ScoringRequest,
    V1RerankReqInput,
)
@@ -147,6 +149,37 @@ async def lifespan(fast_api_app: FastAPI):
    )

    server_args: ServerArgs = fast_api_app.server_args

    tool_server = None
    if server_args.tool_server == "demo":
        from sglang.srt.entrypoints.openai.tool_server import DemoToolServer

        tool_server = DemoToolServer()
    elif server_args.tool_server:
        from sglang.srt.entrypoints.openai.tool_server import MCPToolServer

        tool_server = MCPToolServer()
        await tool_server.add_tool_server(server_args.tool_server)

    try:
        from sglang.srt.entrypoints.openai.serving_responses import (
            OpenAIServingResponses,
        )

        fast_api_app.state.openai_serving_responses = OpenAIServingResponses(
            _global_state.tokenizer_manager,
            _global_state.template_manager,
            enable_prompt_tokens_details=True,
            enable_force_include_usage=True,
            tool_server=tool_server,
        )
    except Exception as e:
        # print stack trace
        import traceback

        traceback.print_exc()
        logger.warning(f"Can not initialize OpenAIServingResponses, error: {e}")

    if server_args.warmups is not None:
        await execute_warmups(
            server_args.disaggregation_mode,
@@ -843,6 +876,42 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
    )


@app.post("/v1/responses", dependencies=[Depends(validate_json_request)])
async def v1_responses_request(request: dict, raw_request: Request):
    """Endpoint for the responses API with reasoning support."""

    request_obj = ResponsesRequest(**request)
    result = await raw_request.app.state.openai_serving_responses.create_responses(
        request_obj, raw_request
    )

    # Handle streaming responses
    if isinstance(result, AsyncGenerator):
        return StreamingResponse(
            result,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    return result


@app.get("/v1/responses/{response_id}")
async def v1_retrieve_responses(response_id: str, raw_request: Request):
    """Retrieve a response by ID."""
    return await raw_request.app.state.openai_serving_responses.retrieve_responses(
        response_id
    )


@app.post("/v1/responses/{response_id}/cancel")
async def v1_cancel_responses(response_id: str, raw_request: Request):
    """Cancel a background response."""
    return await raw_request.app.state.openai_serving_responses.cancel_responses(
        response_id
    )


@app.api_route(
    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)

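A minimal usage sketch (not part of the diff) for the new /v1/responses endpoint, assuming an SGLang server listening on localhost:30000; the model name is a hypothetical placeholder. The payload fields map to ResponsesRequest in the protocol module below.

import json
import requests

payload = {
    "model": "openai/gpt-oss-120b",  # hypothetical model name, an assumption
    "input": "Summarize the Harmony message format in one sentence.",
    "reasoning": {"effort": "low"},
    "stream": False,
}
resp = requests.post("http://localhost:30000/v1/responses", json=payload)
print(json.dumps(resp.json(), indent=2))
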
@@ -14,9 +14,18 @@
"""Pydantic models for OpenAI API protocol"""

import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union

from openai.types.responses import (
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseReasoningItem,
)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from pydantic import (
    BaseModel,
    Field,
@@ -84,6 +93,7 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0
    # only used to return cached tokens when --enable-cache-report is set
    prompt_tokens_details: Optional[Dict[str, int]] = None
    reasoning_tokens: Optional[int] = 0


class StreamOptions(BaseModel):
@@ -428,6 +438,13 @@ class ChatCompletionRequest(BaseModel):
        default="auto", examples=["none"]
    )  # noqa
    return_hidden_states: bool = False
    reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field(
        default="medium",
        description="Constrains effort on reasoning for reasoning models. "
        "'low' is the least effort, 'high' is the most effort. Reducing reasoning effort can "
        "result in faster responses and fewer tokens used on reasoning in a response. "
        "Currently only supported for OpenAI models.",
    )

    @model_validator(mode="before")
    @classmethod
@@ -619,6 +636,196 @@ OpenAIServingRequest = Union[
]


# Response API protocol definitions
class ResponseReasoningParam(BaseModel):
    """Reasoning parameters for responses."""

    effort: Optional[Literal["low", "medium", "high"]] = Field(
        default="medium",
        description="Constrains effort on reasoning for reasoning models.",
    )


class ResponseTool(BaseModel):
    """Tool definition for responses."""

    type: Literal["web_search_preview", "code_interpreter"] = Field(
        description="Type of tool to enable"
    )


ResponseInputOutputItem: TypeAlias = Union[
    ResponseInputItemParam,
    "ResponseReasoningItem",
    ResponseFunctionToolCall,
]


class ResponsesRequest(BaseModel):
    """Request body for v1/responses endpoint."""

    # Core OpenAI API fields (ordered by official documentation)
    background: Optional[bool] = False
    include: Optional[
        List[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ]
        ]
    ] = None
    input: Union[str, List[ResponseInputOutputItem]]
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    max_tool_calls: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None
    model: Optional[str] = None  # Made optional to match vLLM
    parallel_tool_calls: Optional[bool] = True
    previous_response_id: Optional[str] = None
    reasoning: Optional[ResponseReasoningParam] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: Optional[bool] = True
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    tool_choice: Literal["auto", "required", "none"] = "auto"
    tools: List[ResponseTool] = Field(default_factory=list)
    top_logprobs: Optional[int] = 0
    top_p: Optional[float] = None
    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
    user: Optional[str] = None

    # Extra SGLang parameters
    request_id: str = Field(
        default_factory=lambda: f"resp_{uuid.uuid4().hex}",
        description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
    )
    priority: int = Field(default=0, description="Request priority")

    # SGLang-specific sampling parameters
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stop: Optional[Union[str, List[str]]] = None
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0

    # Default sampling parameters
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 0.7,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
        "repetition_penalty": 1.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_params: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Convert to sampling parameters for generation."""
        if default_params is None:
            default_params = {}

        # Use max_output_tokens if available, otherwise use max_tokens for backwards compatibility
        if self.max_output_tokens is not None:
            max_tokens = min(self.max_output_tokens, default_max_tokens)
        else:
            max_tokens = default_max_tokens

        # Avoid exceed the context length by minus 1 token
        max_tokens -= 1

        # Get parameters with defaults
        temperature = self.temperature
        if temperature is None:
            temperature = default_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        top_p = self.top_p
        if top_p is None:
            top_p = default_params.get("top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])

        params = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "stop": self.stop,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "repetition_penalty": self.repetition_penalty,
        }

        # Apply any additional default parameters
        for key, value in default_params.items():
            if key not in params or params[key] is None:
                params[key] = value

        return params


class PromptTokenUsageInfo(BaseModel):
    """Prompt token usage details."""

    cached_tokens: int = 0


class ResponsesResponse(BaseModel):
    """Response body for v1/responses endpoint."""

    id: str = Field(default_factory=lambda: f"resp_{time.time()}")
    object: Literal["response"] = "response"
    created_at: int = Field(default_factory=lambda: int(time.time()))
    model: str

    output: List[
        Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
    ] = Field(default_factory=list)
    status: Literal["queued", "in_progress", "completed", "failed", "cancelled"]
    usage: Optional[UsageInfo] = None
    parallel_tool_calls: bool = True
    tool_choice: str = "auto"
    tools: List[ResponseTool] = Field(default_factory=list)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: Any,
        model_name: str,
        created_time: int,
        output: List[
            Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
        ],
        status: str,
        usage: Optional[UsageInfo],
    ) -> "ResponsesResponse":
        """Create a response from a request."""
        return cls(
            id=request.request_id,
            created_at=created_time,
            model=model_name,
            output=output,
            status=status,
            usage=usage,
            parallel_tool_calls=request.parallel_tool_calls or True,
            tool_choice=request.tool_choice,
            tools=request.tools,
        )


class RequestResponseMetadata(BaseModel):
    """Metadata for request/response tracking."""

    request_id: str
    final_usage_info: Optional[UsageInfo] = None


@dataclass
class MessageProcessingResult:
    """Result of processing chat messages and applying templates.
@@ -645,3 +852,22 @@ class MessageProcessingResult:
    modalities: List[str]
    stop: List[str]
    tool_call_constraint: Optional[Any] = None


class ResponseReasoningTextContent(BaseModel):
    text: str
    type: Literal["reasoning_text"] = "reasoning_text"


class ResponseReasoningItem(BaseModel):
    id: str
    content: list[ResponseReasoningTextContent] = Field(default_factory=list)
    summary: list = Field(default_factory=list)
    type: Literal["reasoning"] = "reasoning"
    encrypted_content: Optional[str] = None
    status: Optional[Literal["in_progress", "completed", "incomplete"]]


ResponseInputOutputItem: TypeAlias = Union[
    ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
]

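A short sketch (not part of the diff) showing how ResponsesRequest.to_sampling_params above resolves explicit request values against the built-in defaults.

from sglang.srt.entrypoints.openai.protocol import ResponsesRequest

req = ResponsesRequest(input="hi", max_output_tokens=128, temperature=None)
params = req.to_sampling_params(default_max_tokens=4096)
# max_new_tokens = min(128, 4096) - 1 = 127; temperature falls back to 0.7.
print(params["max_new_tokens"], params["temperature"])
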
@@ -7,8 +7,18 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from openai_harmony import Message as OpenAIMessage

from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.harmony_utils import (
    get_developer_message,
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_input,
    parse_output_into_messages,
    render_for_completion,
)
from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
@@ -51,6 +61,26 @@ class OpenAIServingChat(OpenAIServingBase):
    ):
        super().__init__(tokenizer_manager)
        self.template_manager = template_manager
        self.use_harmony = (
            self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
        )

        if self.use_harmony:
            from sglang.srt.function_call.harmony_tool_parser import (
                HarmonyToolCallParser,
            )

            self.harmony_tool_parser = HarmonyToolCallParser()

        # NOTE While OpenAI's chat completion API supports browsing
        # for some models, currently vLLM doesn't support it. Please use the
        # Responses API instead.
        self.supports_browsing = False
        self.browser_tool = None
        # NOTE: Chat completion API does not support code interpreter.
        # Please use the Responses API instead.
        self.supports_code_interpreter = False
        self.python_tool = None

    def _request_id_prefix(self) -> str:
        return "chatcmpl-"
@@ -77,41 +107,66 @@ class OpenAIServingChat(OpenAIServingBase):
        is_multimodal = self.tokenizer_manager.model_config.is_multimodal

        # Process messages and apply chat template
        processed_messages = self._process_messages(request, is_multimodal)
        if not self.use_harmony:
            processed_messages = self._process_messages(request, is_multimodal)

        # Build sampling parameters
        sampling_params = self._build_sampling_params(
            request, processed_messages.stop, processed_messages.tool_call_constraint
        )
            # Build sampling parameters
            sampling_params = self._build_sampling_params(
                request,
                processed_messages.stop,
                processed_messages.tool_call_constraint,
            )

        # Handle single vs multiple requests
        if is_multimodal:
            prompt_kwargs = {"text": processed_messages.prompt}
        else:
            if isinstance(processed_messages.prompt_ids, str):
                prompt_kwargs = {"text": processed_messages.prompt_ids}
            # Handle single vs multiple requests
            if is_multimodal:
                prompt_kwargs = {"text": processed_messages.prompt}
            else:
                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
                if isinstance(processed_messages.prompt_ids, str):
                    prompt_kwargs = {"text": processed_messages.prompt_ids}
                else:
                    prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

        adapted_request = GenerateReqInput(
            **prompt_kwargs,
            image_data=processed_messages.image_data,
            video_data=processed_messages.video_data,
            audio_data=processed_messages.audio_data,
            sampling_params=sampling_params,
            return_logprob=request.logprobs,
            logprob_start_len=-1,
            top_logprobs_num=request.top_logprobs or 0,
            stream=request.stream,
            return_text_in_logprobs=True,
            modalities=processed_messages.modalities,
            lora_path=request.lora_path,
            bootstrap_host=request.bootstrap_host,
            bootstrap_port=request.bootstrap_port,
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
        )
            adapted_request = GenerateReqInput(
                **prompt_kwargs,
                image_data=processed_messages.image_data,
                video_data=processed_messages.video_data,
                audio_data=processed_messages.audio_data,
                sampling_params=sampling_params,
                return_logprob=request.logprobs,
                logprob_start_len=-1,
                top_logprobs_num=request.top_logprobs or 0,
                stream=request.stream,
                return_text_in_logprobs=True,
                modalities=processed_messages.modalities,
                lora_path=request.lora_path,
                bootstrap_host=request.bootstrap_host,
                bootstrap_port=request.bootstrap_port,
                bootstrap_room=request.bootstrap_room,
                return_hidden_states=request.return_hidden_states,
                rid=request.rid,
            )
        else:
            processed_messages, prompt_ids = self._make_request_with_harmony(request)

            adapted_request = GenerateReqInput(
                input_ids=prompt_ids,
                sampling_params=self._build_sampling_params(
                    request,
                    request.stop,
                    tool_call_constraint=None,
                ),
                stream=request.stream,
                return_logprob=request.logprobs,
                logprob_start_len=-1,
                top_logprobs_num=request.top_logprobs or 0,
                return_text_in_logprobs=True,
                lora_path=request.lora_path,
                bootstrap_host=request.bootstrap_host,
                bootstrap_port=request.bootstrap_port,
                bootstrap_room=request.bootstrap_room,
                return_hidden_states=request.return_hidden_states,
                rid=request.rid,
            )

        return adapted_request, request

@@ -402,6 +457,12 @@ class OpenAIServingChat(OpenAIServingBase):
        cached_tokens = {}
        hidden_states = {}

        # Harmony tracking
        if self.use_harmony:
            harmony_parsers = [
                get_streamable_parser_for_assistant() for _ in range(request.n)
            ]

        try:
            async for content in self.tokenizer_manager.generate_request(
                adapted_request, raw_request
@@ -449,14 +510,57 @@ class OpenAIServingChat(OpenAIServingBase):
                    yield f"data: {chunk.model_dump_json()}\n\n"

                # Process content delta
                stream_buffer = stream_buffers.get(index, "")
                delta = content["text"][len(stream_buffer) :]
                stream_buffers[index] = stream_buffer + delta
                if self.use_harmony:
                    harmony_parser = harmony_parsers[index]

                    new_token_ids = content["output_ids"]
                    for token_id in new_token_ids:
                        harmony_parser.process(token_id)

                    is_final = harmony_parser.current_channel == "final"
                    is_analysis = harmony_parser.current_channel == "analysis"
                    delta = harmony_parser.last_content_delta or ""

                    if is_analysis:
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(reasoning_content=delta),
                            finish_reason=None,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            created=int(time.time()),
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
                        continue

                    choice_data = ChatCompletionResponseStreamChoice(
                        index=index,
                        delta=DeltaMessage(content=delta if delta else None),
                        finish_reason=None,
                        matched_stop=None,
                        logprobs=choice_logprobs,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=int(time.time()),
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"
                    continue
                else:
                    stream_buffer = stream_buffers.get(index, "")
                    delta = content["text"][len(stream_buffer) :]
                    stream_buffers[index] = stream_buffer + delta

                # Handle reasoning content
                if (
                    self.tokenizer_manager.server_args.reasoning_parser
                    and request.separate_reasoning
                    and not self.use_harmony
                ):
                    reasoning_text, delta = self._process_reasoning_stream(
                        index, delta, reasoning_parser_dict, content, request
@@ -475,8 +579,27 @@ class OpenAIServingChat(OpenAIServingBase):
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"

                if self.use_harmony and not is_final:
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=index,
                        delta=DeltaMessage(reasoning_content=delta),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=int(time.time()),
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"

                # Handle tool calls
                if request.tool_choice != "none" and request.tools:
                # TODO: support tool call parsing for harmony
                if (
                    request.tool_choice != "none"
                    and request.tools
                    and not self.use_harmony
                ):
                    async for chunk in self._process_tool_call_stream(
                        index,
                        delta,
@@ -502,7 +625,7 @@ class OpenAIServingChat(OpenAIServingBase):
                if delta:
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=index,
                        delta=DeltaMessage(content=delta if delta else None),
                        delta=DeltaMessage(content=delta),
                        finish_reason=None,
                        matched_stop=None,
                        logprobs=choice_logprobs,
@@ -640,6 +763,76 @@ class OpenAIServingChat(OpenAIServingBase):

            finish_reason = ret_item["meta_info"]["finish_reason"]
            text = ret_item["text"]
            output_ids = ret_item["output_ids"]

            if self.use_harmony:
                parser = parse_output_into_messages(output_ids)
                output_msgs = parser.messages
                if len(output_msgs) == 0:
                    # The generation has stopped during reasoning.
                    is_tool_call = False
                    reasoning_content = parser.current_content
                    final_content = None
                elif len(output_msgs) == 1:
                    # The generation has stopped during final message.
                    is_tool_call = False
                    reasoning_content = output_msgs[0].content[0].text
                    final_content = parser.current_content
                else:
                    if len(output_msgs) != 2:
                        raise ValueError(
                            "Expected 2 output messages (reasoning and final), "
                            f"but got {len(output_msgs)}."
                        )
                    reasoning_msg, final_msg = output_msgs
                    reasoning_content = reasoning_msg.content[0].text
                    final_content = final_msg.content[0].text
                    is_tool_call = final_msg.recipient is not None

                if is_tool_call:
                    # Extract tool call information from final message
                    tool_call = (
                        self.harmony_tool_parser.extract_tool_calls_from_message(
                            final_msg
                        )
                    )
                    tool_calls = [tool_call] if tool_call else []

                    message = ChatMessage(
                        role="assistant",
                        reasoning_content=reasoning_content,
                        content=None,  # Tool calls don't have regular content
                        tool_calls=tool_calls,
                    )
                else:
                    # Normal message
                    message = ChatMessage(
                        role="assistant",
                        reasoning_content=reasoning_content,
                        content=final_content,
                    )

                if is_tool_call:
                    finish_reason_type = "tool_calls"
                elif finish_reason:
                    finish_reason_type = (
                        finish_reason["type"] if finish_reason else "stop"
                    )
                else:
                    finish_reason_type = "stop"
                choice_data = ChatCompletionResponseChoice(
                    index=idx,
                    message=message,
                    logprobs=choice_logprobs,
                    finish_reason=finish_reason_type,
                    matched_stop=(
                        finish_reason["matched"]
                        if finish_reason and "matched" in finish_reason
                        else None
                    ),
                )
                choices.append(choice_data)
                continue

            # Handle reasoning content
            reasoning_text = None
@@ -978,3 +1171,33 @@ class OpenAIServingChat(OpenAIServingBase):
        return f"data: {chunk.model_dump_json()}\n\n"

        return None

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
    ):
        messages: list[OpenAIMessage] = []

        # Add system message.
        # In Chat Completion API, browsing is enabled by default if the model
        # supports it.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter
        sys_msg = get_system_message(
            reasoning_effort=request.reasoning_effort,
            browser_description=None,
            python_description=None,
        )
        messages.append(sys_msg)

        # Add developer message.
        dev_msg = get_developer_message()
        messages.append(dev_msg)

        # Add user message.
        for chat_msg in request.messages:
            messages.append(parse_chat_input(chat_msg))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)
        return messages, prompt_token_ids

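A hedged client-side sketch (not part of the diff) of consuming the harmony-aware Chat Completions stream: analysis-channel tokens arrive as delta.reasoning_content and final-channel tokens as delta.content. It assumes an SGLang server on localhost:30000 and a hypothetical model name; whether the extra reasoning_content field is surfaced as an attribute depends on the client library version.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # hypothetical model name, an assumption
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if getattr(delta, "reasoning_content", None):
        print("[reasoning]", delta.reasoning_content, end="")
    elif delta.content:
        print(delta.content, end="")
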
python/sglang/srt/entrypoints/openai/serving_responses.py (new file, 1273 lines)
File diff suppressed because it is too large.
python/sglang/srt/entrypoints/openai/tool_server.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any

logger = logging.getLogger(__name__)
try:
    from mcp import ClientSession
    from mcp.client.sse import sse_client
    from mcp.types import ListToolsResult
except ImportError:
    logger.warning("Ignoring mcp import error")

from openai_harmony import ToolDescription, ToolNamespaceConfig


async def list_server_and_tools(server_url: str):

    async with sse_client(url=server_url) as streams, ClientSession(
        *streams
    ) as session:
        initialize_response = await session.initialize()
        list_tools_response = await session.list_tools()
        return initialize_response, list_tools_response


def trim_schema(schema: dict) -> dict:
    # Turn JSON Schema from MCP generated into Harmony's variant.
    if "title" in schema:
        del schema["title"]
    if "default" in schema and schema["default"] is None:
        del schema["default"]
    if "anyOf" in schema:
        # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
        # into "type": ["type-1", "type-2"]
        # if there's more than 1 types, also remove "null" type as Harmony will
        # just ignore it
        types = [
            type_dict["type"]
            for type_dict in schema["anyOf"]
            if type_dict["type"] != "null"
        ]
        schema["type"] = types
        del schema["anyOf"]
    if "properties" in schema:
        schema["properties"] = {
            k: trim_schema(v) for k, v in schema["properties"].items()
        }
    return schema


def post_process_tools_description(
    list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
    # Adapt the MCP tool result for Harmony
    for tool in list_tools_result.tools:
        tool.inputSchema = trim_schema(tool.inputSchema)

    # Some tools schema don't need to be part of the prompt (e.g. simple text
    # in text out for Python)
    list_tools_result.tools = [
        tool
        for tool in list_tools_result.tools
        if getattr(tool.annotations, "include_in_prompt", True)
    ]

    return list_tools_result


class ToolServer(ABC):

    @abstractmethod
    def has_tool(self, tool_name: str):
        pass

    @abstractmethod
    def get_tool_description(self, tool_name: str):
        pass

    @abstractmethod
    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ...


class MCPToolServer(ToolServer):

    def __init__(self):
        self.harmony_tool_descriptions = {}

    async def add_tool_server(self, server_url: str):
        tool_urls = server_url.split(",")
        self.harmony_tool_descriptions = {}
        self.urls: dict[str, str] = {}
        for url in tool_urls:
            url = f"http://{url}/sse"
            initialize_response, list_tools_response = await list_server_and_tools(url)

            list_tools_response = post_process_tools_description(list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.inputSchema,
                    )
                    for tool in list_tools_response.tools
                ],
            )
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            if tool_from_mcp.name not in self.urls:
                self.urls[tool_from_mcp.name] = url
            else:
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name,
                    url,
                )

    def has_tool(self, tool_name: str):
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self, tool_name: str):
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        url = self.urls.get(tool_name)
        if url:
            async with sse_client(url=url) as streams, ClientSession(
                *streams
            ) as session:
                await session.initialize()
                yield session
        else:
            logger.warning("Tool %s not found", tool_name)


class DemoToolServer(ToolServer):

    def __init__(self):
        from sglang.srt.entrypoints.tool import (
            HarmonyBrowserTool,
            HarmonyPythonTool,
            Tool,
        )

        self.tools: dict[str, Tool] = {}
        browser_tool = HarmonyBrowserTool()
        if browser_tool.enabled:
            self.tools["browser"] = browser_tool
        python_tool = HarmonyPythonTool()
        if python_tool.enabled:
            self.tools["python"] = python_tool

    def has_tool(self, tool_name: str):
        return tool_name in self.tools

    def get_tool_description(self, tool_name: str):
        if tool_name not in self.tools:
            return None
        if tool_name == "browser":
            return ToolNamespaceConfig.browser()
        elif tool_name == "python":
            return ToolNamespaceConfig.python()
        else:
            raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        yield self.tools[tool_name]
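A short sketch (not part of the diff) of what trim_schema above does to an MCP-style JSON Schema before it is handed to Harmony.

from sglang.srt.entrypoints.openai.tool_server import trim_schema

mcp_schema = {
    "title": "Search",
    "properties": {
        "query": {
            "title": "Query",
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
        },
    },
}
print(trim_schema(mcp_schema))
# -> {"properties": {"query": {"type": ["string"]}}}
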
python/sglang/srt/entrypoints/tool.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Avoid circular import.
    from sglang.srt.entrypoints.context import ConversationContext

logger = logging.getLogger(__name__)


class Tool(ABC):

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        pass


class HarmonyBrowserTool(Tool):

    def __init__(self):
        self.enabled = True
        exa_api_key = os.getenv("EXA_API_KEY")
        if not exa_api_key:
            self.enabled = False
            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
            return

        try:
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError:
            self.enabled = False
            logger.warning_once("gpt_oss is not installed, browsing is disabled")
            return

        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
        logger.info_once("Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        tool_output_msgs = []
        async for msg in self.browser_tool.process(last_msg):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    @property
    def tool_config(self) -> Any:
        return self.browser_tool.tool_config


class HarmonyPythonTool(Tool):

    def __init__(self):
        self.enabled = True

        try:
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed, code interpreter is disabled"
            )
            return

        self.python_tool = PythonTool()
        logger.info_once("Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        tool_output_msgs = []
        async for msg in self.python_tool.process(last_msg):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    @property
    def tool_config(self) -> Any:
        return self.python_tool.tool_config
130
python/sglang/srt/function_call/harmony_tool_parser.py
Normal file
@@ -0,0 +1,130 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Harmony tool call parser for processing tool calls in harmony models."""

import uuid
from typing import List, Optional, Tuple

from sglang.srt.entrypoints.openai.protocol import (
    ChatMessage,
    FunctionResponse,
    ToolCall,
)


class HarmonyToolCallParser:
    """Parser for extracting tool calls from harmony model outputs."""

    def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
        """
        Extract the tool call from a single message if it is a tool call.

        Args:
            msg: The harmony message

        Returns:
            ToolCall if the message is a tool call, None otherwise
        """
        if (
            msg.channel == "commentary"
            and msg.recipient
            and msg.recipient.startswith("functions.")
        ):
            function_name = msg.recipient.split(".")[-1]
            arguments = msg.content[0].text if msg.content else "{}"

            return ToolCall(
                id=f"call_{uuid.uuid4().hex[:24]}",
                function=FunctionResponse(
                    name=function_name,
                    arguments=arguments,
                ),
            )
        return None

    def process_streaming_chunk(
        self,
        harmony_parser,
        index: int,
        tool_call_trackers: dict,
        stream_buffers: dict,
    ) -> Tuple[Optional[dict], bool, Optional[str]]:
        """
        Process a streaming chunk for tool calls.

        Args:
            harmony_parser: The harmony parser instance
            index: The choice index
            tool_call_trackers: Dict tracking tool calls per choice
            stream_buffers: Dict for buffering content

        Returns:
            Tuple of (tool_call_data, is_tool_call, delta)
        """
        # Check if we're in a tool call
        is_tool_call = (
            harmony_parser.current_channel == "commentary"
            and harmony_parser.current_recipient
            and harmony_parser.current_recipient.startswith("functions.")
        )

        delta = harmony_parser.last_content_delta or ""
        tool_call_data = None

        if is_tool_call:
            # Handle tool call streaming
            function_name = harmony_parser.current_recipient.split(".")[-1]

            # Track tool call indices per choice
            if index not in tool_call_trackers:
                tool_call_trackers[index] = {"count": 0, "current_function": None}

            # Check if we just started a new tool call
            tool_call_tracker = tool_call_trackers[index]
            if tool_call_tracker["current_function"] != function_name:
                # New tool call started
                tool_call_tracker["current_function"] = function_name
                tool_call_index = tool_call_tracker["count"]
                tool_call_tracker["count"] += 1

                # Store the tool call index for this function
                tool_call_key = f"{index}_{function_name}"
                stream_buffers[tool_call_key] = {
                    "index": tool_call_index,
                    "content": "",
                }

                tool_call_data = {
                    "id": f"call_{uuid.uuid4().hex[:24]}",
                    "index": tool_call_index,
                    "function_name": function_name,
                    "arguments": delta,
                    "is_first_chunk": True,
                }
            else:
                # Subsequent chunks for the same tool call
                tool_call_key = f"{index}_{function_name}"
                tool_call_index = stream_buffers[tool_call_key]["index"]

                tool_call_data = {
                    "id": None,
                    "index": tool_call_index,
                    "function_name": None,
                    "arguments": delta,
                    "is_first_chunk": False,
                }

            stream_buffers[tool_call_key]["content"] += delta

        return tool_call_data, is_tool_call, delta
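A short, hedged usage sketch for the non-streaming path: given already-parsed harmony messages (for example, the messages produced by the streamable parser in harmony_utils), the parser turns any functions.* commentary messages into OpenAI-style ToolCall objects. The helper name below is hypothetical.

from sglang.srt.function_call.harmony_tool_parser import HarmonyToolCallParser


def collect_tool_calls(messages) -> list:
    # `messages` is assumed to be a list of parsed harmony Message objects.
    parser = HarmonyToolCallParser()
    tool_calls = []
    for msg in messages:
        tool_call = parser.extract_tool_calls_from_message(msg)
        if tool_call is not None:  # only commentary messages addressed to functions.* qualify
            tool_calls.append(tool_call)
    return tool_calls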
@@ -216,7 +216,7 @@ class DetokenizerManager:
                rids=recv_obj.rids,
                finished_reasons=recv_obj.finished_reasons,
                output_strs=output_strs,
                output_ids=None,
                output_ids=recv_obj.decode_ids,
                prompt_tokens=recv_obj.prompt_tokens,
                completion_tokens=recv_obj.completion_tokens,
                cached_tokens=recv_obj.cached_tokens,
@@ -126,6 +126,9 @@ class GenerateReqInput:
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
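Client-side illustration (an assumption-heavy sketch, not taken from this diff): with the OpenAI SDK pinned to >=1.99.1, a background request against the new /v1/responses endpoint might look as follows. The base URL, model name, and polling flow are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.responses.create(
    model="default",
    input="Summarize the harmony message format in one sentence.",
    background=True,  # intended to map to the new `background` field on the server
)
print(resp.id, resp.status)  # poll for the result later instead of blocking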
@@ -560,6 +563,9 @@ class EmbeddingReqInput:
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    # For background responses (OpenAI responses API)
    background: bool = False

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:

                req.send_decode_id_offset = len(decode_ids)
                read_offsets.append(read_offset)
                if self.skip_tokenizer_init:
                    output_ids.append(req.output_ids[send_token_offset:])
                output_ids.append(req.output_ids[send_token_offset:])
                req.send_token_offset = len(req.output_ids)
                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                spaces_between_special_tokens.append(
@@ -750,7 +750,11 @@ class TokenizerManager:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
@@ -805,7 +809,11 @@ class TokenizerManager:
                if obj.stream:
                    yield out
                else:
                    if request is not None and await request.is_disconnected():
                    if (
                        request is not None
                        and not obj.background
                        and await request.is_disconnected()
                    ):
                        # Abort the request for disconnected requests (non-streaming, running)
                        self.abort_request(obj.rid)
                        # Use exception to kill the whole call stack and asyncio task
@@ -1548,8 +1556,17 @@ class TokenizerManager:

        if isinstance(recv_obj, BatchStrOut):
            state.text += recv_obj.output_strs[i]
            if state.obj.stream:
                state.output_ids.extend(recv_obj.output_ids[i])
                output_token_ids = state.output_ids[state.last_output_offset :]
                state.last_output_offset = len(state.output_ids)
            else:
                state.output_ids.extend(recv_obj.output_ids[i])
                output_token_ids = state.output_ids.copy()

            out_dict = {
                "text": state.text,
                "output_ids": output_token_ids,
                "meta_info": meta_info,
            }
        elif isinstance(recv_obj, BatchTokenIDOut):
@@ -274,6 +274,9 @@ class ServerArgs:
    enable_pdmux: bool = False
    sm_group_num: int = 3

    # For tool server
    tool_server: Optional[str] = None

    # Deprecated arguments
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
@@ -1916,6 +1919,14 @@ class ServerArgs:
            help="Disable mmap while loading weight using safetensors.",
        )

        # For tool server
        parser.add_argument(
            "--tool-server",
            type=str,
            default=None,
            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
        )

        # Deprecated arguments
        parser.add_argument(
            "--enable-ep-moe",
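Illustrative only: one way to exercise the new flag from Python, assuming the parser.add_argument calls above live in ServerArgs.add_cli_args and that --model-path is the usual required companion argument; the model path is a placeholder.

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "openai/gpt-oss-20b", "--tool-server", "demo"]
)
print(args.tool_server)  # "demo", or a comma-separated list of tool server URLs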
@@ -41,6 +41,7 @@ import tempfile
import threading
import time
import traceback
import uuid
import warnings
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
@@ -233,6 +234,10 @@ def is_flashinfer_available():
    return importlib.util.find_spec("flashinfer") is not None and is_cuda()


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
)