Files
sglang/python/sglang/srt/entrypoints/context.py
Chang Su 92cc32d9fc Support v1/responses and use harmony in serving_chat (#8837)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
2025-08-06 16:20:34 -07:00

245 lines
7.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union
logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
except ImportError:
logger.warning("Ignoring mcp import error")
from openai_harmony import Author, Message, Role, StreamState, TextContent
from sglang.srt.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool
class ConversationContext(ABC):
    """Abstract interface for the state of one conversation.

    Concrete contexts decide how new output is recorded, whether the latest
    state calls for a built-in tool, how that tool is invoked, and how the
    conversation is turned back into prompt token ids.
    """

    @abstractmethod
    def append_output(self, output) -> None:
        """Record ``output`` in the context."""

    @abstractmethod
    async def call_tool(self) -> list[Message]:
        """Invoke the pending tool call and return the messages it produced."""

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        """Return whether a built-in tool call is currently required."""

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        """Return the conversation rendered as a list of token ids."""
class SimpleContext(ConversationContext):
    """Minimal context that only remembers the most recent output.

    It never asks for a built-in tool call, so the tool and rendering hooks
    are not expected to run and raise ``NotImplementedError`` if they do.
    """

    def __init__(self):
        # Most recent value handed to append_output(); None until then.
        self.last_output = None

    def append_output(self, output) -> None:
        """Overwrite the stored output with ``output``."""
        self.last_output = output

    def need_builtin_tool_call(self) -> bool:
        """Always report that no built-in tool call is needed."""
        return False

    async def call_tool(self) -> list[Message]:
        """Unsupported here; always raises ``NotImplementedError``."""
        raise NotImplementedError("Should not be called.")

    def render_for_completion(self) -> list[int]:
        """Unsupported here; always raises ``NotImplementedError``."""
        raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
    """Conversation context backed by a streamable Harmony parser.

    Generated token ids are fed through ``self.parser`` and the parsed
    messages are appended to the message list given at construction time.
    When the last message is addressed to a ``browser.*`` or ``python``
    recipient, the matching entry of ``tool_sessions`` is used to run it.
    """

    def __init__(
        self,
        messages: list,
        tool_sessions: dict[str, Union["ClientSession", Tool]],
    ):
        """Store the initial messages and tool sessions.

        Args:
            messages: Conversation so far; this list is extended in place by
                ``append_output``.
            tool_sessions: Tool sessions keyed by tool name. ``call_tool``
                looks up the ``"browser"`` and ``"python"`` keys.
        """
        # TODO: Remove the Union[ClientSession, Tool] hack by going through
        # MCP for every tool.
        self._messages = messages
        self.tool_sessions = tool_sessions

        self.parser = get_streamable_parser_for_assistant()
        self.num_init_messages = len(messages)
        # TODO: usage accounting; only the counters reported in meta_info
        # are updated by append_output (num_reasoning_tokens is never set).
        self.num_prompt_tokens = 0
        self.num_cached_tokens = 0
        self.num_output_tokens = 0
        self.num_reasoning_tokens = 0

    def append_output(self, output) -> None:
        """Append generated output or tool messages to the conversation.

        Args:
            output: Either a dict with an ``"output_ids"`` entry (generated
                token ids, optionally accompanied by a ``"meta_info"`` dict),
                or an iterable of messages that is appended as-is.
        """
        if isinstance(output, dict) and "output_ids" in output:
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky: drop everything before the first occurrence of
            # token 200006 (presumably the Harmony start token - confirm).
            try:
                start_index = output_token_ids.index(200006)
            except ValueError:
                pass
            else:
                output_token_ids = output_token_ids[start_index:]

            for token_id in output_token_ids:
                self.parser.process(token_id)
            # NOTE(review): self.parser is created once in __init__ and
            # reused here. If parser.messages is cumulative, a second
            # generation output would re-append earlier messages - confirm.
            output_msgs = self.parser.messages

            # meta_info is optional; tolerate it being absent or not a dict.
            meta_info = output.get("meta_info")
            if isinstance(meta_info, dict):
                # Test for the same key that is read. The previous check
                # looked for "prompt_token_ids" and then read "prompt_tokens".
                if "prompt_tokens" in meta_info:
                    self.num_prompt_tokens = meta_info["prompt_tokens"]
                if "cached_tokens" in meta_info:
                    self.num_cached_tokens = meta_info["cached_tokens"]
                if "completion_tokens" in meta_info:
                    # Output tokens accumulate across calls.
                    self.num_output_tokens += meta_info["completion_tokens"]
        else:
            # Tool output: already a list of messages.
            output_msgs = output
        self._messages.extend(output_msgs)

    @property
    def messages(self) -> list:
        """The message list this context reads from."""
        return self._messages

    def need_builtin_tool_call(self) -> bool:
        """Return True when the last message targets ``browser.*`` or ``python``.

        An empty conversation has no pending tool call, matching the
        empty-message handling in ``call_tool``.
        """
        if not self.messages:
            return False
        recipient = self.messages[-1].recipient
        return recipient is not None and recipient.startswith(
            ("browser.", "python")
        )

    async def call_tool(self) -> list[Message]:
        """Run the tool addressed by the last message.

        Returns:
            The messages produced by the tool, or an empty list when there
            are no messages at all.

        Raises:
            ValueError: If the last message is not addressed to a
                ``browser.*`` or ``python`` recipient.
        """
        if not self.messages:
            return []
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is not None:
            if recipient.startswith("browser."):
                return await self.call_search_tool(
                    self.tool_sessions["browser"], last_msg
                )
            elif recipient.startswith("python"):
                return await self.call_python_tool(
                    self.tool_sessions["python"], last_msg
                )
        raise ValueError("No tool call found")

    def render_for_completion(self) -> list[int]:
        """Render the current messages to token ids via the module helper."""
        return render_for_completion(self.messages)

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """Execute a ``browser.*`` call and wrap the result in a tool message.

        A ``Tool`` session is asked for its result directly, with this
        context as argument. Any other session is called through
        ``call_tool(name, args)``, where ``name`` is the second
        dot-separated component of the recipient and ``args`` is the JSON
        decoded from the first content item's text.
        """
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1]
        args = json.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """Execute a ``python`` call and wrap the result in a tool message.

        A ``Tool`` session is asked for its result directly, with this
        context as argument. Any other session is called through
        ``call_tool("python", {"code": ...})`` using the first content
        item's text as the code. The reply reuses the channel of
        ``last_msg``.
        """
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        param = {
            "code": last_msg.content[0].text,
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name="python")
        return [
            Message(
                author=author,
                content=[content],
                channel=last_msg.channel,
                recipient=Role.ASSISTANT,
            )
        ]
class StreamingHarmonyContext(HarmonyContext):
    """Harmony context for streaming, where the parser holds the messages.

    Unlike the base class, ``messages`` is read from ``self.parser``: every
    token, whether generated or rendered from a tool message, is pushed
    through the parser. ``last_tok`` remembers the last token the parser
    has consumed.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``HarmonyContext.__init__``."""
        # The base __init__ already creates self.parser, so it is not
        # created a second time here.
        super().__init__(*args, **kwargs)
        self.last_output = None

        self.encoding = get_encoding()
        # Last token id fed to the parser; None until something is processed.
        self.last_tok = None

    @property
    def messages(self) -> list:
        """Messages parsed so far, as reported by the parser."""
        return self.parser.messages

    def append_output(self, output) -> None:
        """Feed generated tokens or a single tool message into the parser.

        Args:
            output: Either a dict with an ``"output_ids"`` entry (generated
                token ids) or a one-element sequence holding a tool message.
        """
        if isinstance(output, dict) and "output_ids" in output:
            # Generation output from SGLang.
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky: drop everything before the first occurrence of
            # token 200006 (presumably the Harmony start token - confirm).
            try:
                start_index = output_token_ids.index(200006)
            except ValueError:
                pass
            else:
                output_token_ids = output_token_ids[start_index:]

            for token_id in output_token_ids:
                self.parser.process(token_id)
            # Keep last_tok in step with the parser on this branch too.
            # Otherwise is_assistant_action_turn() and
            # render_for_completion() would only ever see the last token of
            # a tool message (or None).
            if output_token_ids:
                self.last_tok = output_token_ids[-1]
        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"
            msg = output[0]
            # Sometimes the recipient is not set for tool messages,
            # so we set it to "assistant"
            if msg.author.role == Role.TOOL and msg.recipient is None:
                msg.recipient = "assistant"
            toks = self.encoding.render(msg)
            for tok in toks:
                self.parser.process(tok)
            self.last_tok = toks[-1]

    def is_expecting_start(self) -> bool:
        """Return whether the parser state is ``StreamState.EXPECT_START``."""
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        """Return whether ``last_tok`` is an assistant-action stop token."""
        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

    def render_for_completion(self) -> list[int]:
        """Render the next prompt and bring the parser up to date with it.

        The rendered prompt ends with tokens that open the next turn (the
        original note gives ``<|start|>assistant`` as the example). The
        parser has not seen them, so every token after the last occurrence
        of ``last_tok`` is fed to it, in order.

        Raises:
            IndexError: If ``last_tok`` does not occur in the rendered tokens.
        """
        rendered_tokens = super().render_for_completion()

        # Scan backwards for the last token the parser already consumed.
        for idx in range(len(rendered_tokens) - 1, -1, -1):
            if rendered_tokens[idx] == self.last_tok:
                break
        else:
            # Same exception type the unbounded negative-index scan raised.
            raise IndexError(
                "last processed token not found in rendered completion tokens"
            )

        for tok in rendered_tokens[idx + 1 :]:
            self.parser.process(tok)
        return rendered_tokens