From 92cc32d9fcee0ab6c020891e723f27daeef232ee Mon Sep 17 00:00:00 2001 From: Chang Su Date: Wed, 6 Aug 2025 16:20:34 -0700 Subject: [PATCH] Support v1/responses and use harmony in serving_chat (#8837) Signed-off-by: Xinyuan Tong Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong Co-authored-by: Xinyuan Tong --- python/pyproject.toml | 3 +- python/sglang/srt/entrypoints/context.py | 244 ++++ .../sglang/srt/entrypoints/harmony_utils.py | 370 +++++ python/sglang/srt/entrypoints/http_server.py | 69 + .../sglang/srt/entrypoints/openai/protocol.py | 228 ++- .../srt/entrypoints/openai/serving_chat.py | 295 +++- .../entrypoints/openai/serving_responses.py | 1273 +++++++++++++++++ .../srt/entrypoints/openai/tool_server.py | 174 +++ python/sglang/srt/entrypoints/tool.py | 87 ++ .../srt/function_call/harmony_tool_parser.py | 130 ++ .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/managers/io_struct.py | 6 + .../scheduler_output_processor_mixin.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 21 +- python/sglang/srt/server_args.py | 11 + python/sglang/srt/utils.py | 5 + 16 files changed, 2878 insertions(+), 43 deletions(-) create mode 100644 python/sglang/srt/entrypoints/context.py create mode 100644 python/sglang/srt/entrypoints/harmony_utils.py create mode 100644 python/sglang/srt/entrypoints/openai/serving_responses.py create mode 100644 python/sglang/srt/entrypoints/openai/tool_server.py create mode 100644 python/sglang/srt/entrypoints/tool.py create mode 100644 python/sglang/srt/function_call/harmony_tool_parser.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 3e63ed50c..753d281be 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -29,6 +29,7 @@ runtime_common = [ "modelscope", "msgspec", "ninja", + "openai-harmony==0.0.3", "orjson", "outlines==0.1.11", "packaging", @@ -96,7 +97,7 @@ srt_cpu = ["sglang[runtime_common]", "einops"] # https://vllm-ascend.readthedocs.io/en/latest/installation.html srt_npu = ["sglang[runtime_common]"] -openai = ["openai>=1.0", "tiktoken"] +openai = ["openai>=1.99.1", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] torch_memory_saver = ["torch_memory_saver>=0.0.8"] diff --git a/python/sglang/srt/entrypoints/context.py b/python/sglang/srt/entrypoints/context.py new file mode 100644 index 000000000..0c8bc116d --- /dev/null +++ b/python/sglang/srt/entrypoints/context.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copied from vLLM +import json +import logging +from abc import ABC, abstractmethod +from typing import Union + +logger = logging.getLogger(__name__) + +try: + from mcp import ClientSession +except ImportError: + logger.warning("Ignoring mcp import error") + +from openai_harmony import Author, Message, Role, StreamState, TextContent + +from sglang.srt.entrypoints.harmony_utils import ( + get_encoding, + get_streamable_parser_for_assistant, + render_for_completion, +) +from sglang.srt.entrypoints.tool import Tool + + +class ConversationContext(ABC): + + @abstractmethod + def append_output(self, output) -> None: + pass + + @abstractmethod + async def call_tool(self) -> list[Message]: + pass + + @abstractmethod + def need_builtin_tool_call(self) -> bool: + pass + + @abstractmethod + def render_for_completion(self) -> list[int]: + pass + + +class SimpleContext(ConversationContext): + + def __init__(self): + self.last_output = None + + def append_output(self, output) -> None: + self.last_output = output + + def need_builtin_tool_call(self) -> bool: + return False 
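# Illustrative sketch (not part of this patch) of the request loop the
# ConversationContext interface is designed for. `generate_once` is a hypothetical
# stand-in for a single engine call; the real driver is
# `_generate_with_builtin_tools` in serving_responses.py.
async def _example_agent_loop(context, generate_once) -> None:
    while True:
        output = await generate_once(context.render_for_completion())
        context.append_output(output)
        if not context.need_builtin_tool_call():
            break
        # A built-in tool (browser/python) was addressed: run it and feed the
        # resulting tool messages back into the conversation.
        context.append_output(await context.call_tool())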
+ + async def call_tool(self) -> list[Message]: + raise NotImplementedError("Should not be called.") + + def render_for_completion(self) -> list[int]: + raise NotImplementedError("Should not be called.") + + +class HarmonyContext(ConversationContext): + + def __init__( + self, + messages: list, + tool_sessions: dict[str, Union["ClientSession", Tool]], + ): + # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP + # when demo. + self._messages = messages + self.tool_sessions = tool_sessions + + self.parser = get_streamable_parser_for_assistant() + self.num_init_messages = len(messages) + # TODO + self.num_prompt_tokens = 0 + self.num_cached_tokens = 0 + self.num_output_tokens = 0 + self.num_reasoning_tokens = 0 + + def append_output(self, output) -> None: + if isinstance(output, dict) and "output_ids" in output: + output_token_ids = output["output_ids"] + + # TODO: REMOVE here: + # Very hacky, find the first occurrence of token 200006 and cut from there + try: + start_index = output_token_ids.index(200006) + output_token_ids = output_token_ids[start_index:] + except ValueError: + pass + + for token_id in output_token_ids: + self.parser.process(token_id) + output_msgs = self.parser.messages + + meta_info = output["meta_info"] + + if isinstance(meta_info, dict): + if "prompt_token_ids" in meta_info: + self.num_prompt_tokens = meta_info["prompt_tokens"] + if "cached_tokens" in meta_info: + self.num_cached_tokens = meta_info["cached_tokens"] + if "completion_tokens" in meta_info: + self.num_output_tokens += meta_info["completion_tokens"] + + else: + output_msgs = output + + self._messages.extend(output_msgs) + + @property + def messages(self) -> list: + return self._messages + + def need_builtin_tool_call(self) -> bool: + last_msg = self.messages[-1] + recipient = last_msg.recipient + return recipient is not None and ( + recipient.startswith("browser.") or recipient.startswith("python") + ) + + async def call_tool(self) -> list[Message]: + if not self.messages: + return [] + last_msg = self.messages[-1] + recipient = last_msg.recipient + if recipient is not None: + if recipient.startswith("browser."): + return await self.call_search_tool( + self.tool_sessions["browser"], last_msg + ) + elif recipient.startswith("python"): + return await self.call_python_tool( + self.tool_sessions["python"], last_msg + ) + raise ValueError("No tool call found") + + def render_for_completion(self) -> list[int]: + return render_for_completion(self.messages) + + async def call_search_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [Message(author=author, content=[content], recipient=Role.ASSISTANT)] + + async def call_python_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + param = { + "code": last_msg.content[0].text, + } + result = await tool_session.call_tool("python", param) + result_str = result.content[0].text + + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name="python") + + return [ + Message( + author=author, + 
content=[content], + channel=last_msg.channel, + recipient=Role.ASSISTANT, + ) + ] + + +class StreamingHarmonyContext(HarmonyContext): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_output = None + + self.parser = get_streamable_parser_for_assistant() + self.encoding = get_encoding() + self.last_tok = None + + @property + def messages(self) -> list: + return self.parser.messages + + def append_output(self, output) -> None: + if isinstance(output, dict) and "output_ids" in output: + # RequestOutput from SGLang with outputs + output_token_ids = output["output_ids"] + + # TODO: REMOVE here: + # Very hacky, find the first occurrence of token 200006 and cut from there + # Find the first occurrence of token 200006 and cut from there + try: + start_index = output_token_ids.index(200006) + output_token_ids = output_token_ids[start_index:] + except ValueError: + pass + + for token_id in output_token_ids: + self.parser.process(token_id) + + else: + # Handle the case of tool output in direct message format + assert len(output) == 1, "Tool output should be a single message" + msg = output[0] + # Sometimes the recipient is not set for tool messages, + # so we set it to "assistant" + if msg.author.role == Role.TOOL and msg.recipient is None: + msg.recipient = "assistant" + toks = self.encoding.render(msg) + for tok in toks: + self.parser.process(tok) + self.last_tok = toks[-1] + + def is_expecting_start(self) -> bool: + return self.parser.state == StreamState.EXPECT_START + + def is_assistant_action_turn(self) -> bool: + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions() + + def render_for_completion(self) -> list[int]: + # now this list of tokens as next turn's starting tokens + # `<|start|>assistant``, + # we need to process them in parser. 
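        # In other words: super().render_for_completion() re-renders the whole
        # conversation and appends the next turn's header tokens (e.g.
        # <|start|>assistant) after the last token this parser has already seen
        # (self.last_tok). The loop below walks backwards to that boundary, then
        # feeds the unseen tail to the streaming parser in forward order so the
        # parser state matches the prompt handed back to the engine.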
+ rendered_tokens = super().render_for_completion() + + last_n = -1 + to_process = [] + while rendered_tokens[last_n] != self.last_tok: + to_process.append(rendered_tokens[last_n]) + last_n -= 1 + for tok in reversed(to_process): + self.parser.process(tok) + + return rendered_tokens diff --git a/python/sglang/srt/entrypoints/harmony_utils.py b/python/sglang/srt/entrypoints/harmony_utils.py new file mode 100644 index 000000000..635c37187 --- /dev/null +++ b/python/sglang/srt/entrypoints/harmony_utils.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import datetime +import json +from collections.abc import Iterable +from typing import Literal, Optional, Union + +from openai.types.responses import ( + ResponseOutputItem, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_function_web_search import ( + ActionFind, + ActionOpenPage, + ActionSearch, + ResponseFunctionWebSearch, +) +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent, +) +from openai.types.responses.tool import Tool +from openai_harmony import ( + Author, + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + ReasoningEffort, + Role, + StreamableParser, + SystemContent, + TextContent, + ToolDescription, + load_harmony_encoding, +) + +from sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem +from sglang.srt.utils import random_uuid + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort] + ) + if start_date is None: + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message( + instructions: Optional[str] = None, tools: Optional[list[Tool]] = None +) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. 
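                # (Their descriptions are attached to the system message via
                # SystemContent.with_tools in get_system_message, so there is
                # nothing to add to the developer message here.)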
+ pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions + ) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_response_input( + response_msg: ResponseInputOutputItem, + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]], +) -> Message: + if not isinstance(response_msg, dict): + response_msg = response_msg.model_dump() + if "type" not in response_msg or response_msg["type"] == "message": + role = response_msg["role"] + content = response_msg["content"] + if role == "system": + # User is trying to set a system message. Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + else: + contents = [TextContent(text=text_prefix + c["text"]) for c in content] + msg = Message.from_role_and_contents(role, contents) + elif response_msg["type"] == "function_call_output": + call_id = response_msg["call_id"] + call_response: Optional[ResponseFunctionToolCall] = None + for prev_response in reversed(prev_responses): + if ( + isinstance(prev_response, ResponseFunctionToolCall) + and prev_response.call_id == call_id + ): + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + response_msg["output"], + ) + elif response_msg["type"] == "reasoning": + content = response_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif response_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{response_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {response_msg['type']}") + return msg + + +def parse_response_output(output: ResponseOutputItem) -> Message: + if isinstance(output, ResponseOutputMessage): + role = output.role + contents = [TextContent(text=c.text) for c in output.content] + msg = Message.from_role_and_contents(role, contents) + return msg + elif isinstance(output, ResponseFunctionToolCall): + msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(output.name) + msg = msg.with_content_type("json") + return msg + else: + raise ValueError(f"Unknown output type: {type(output)}") + + +def parse_chat_input(chat_msg) -> Message: + role = chat_msg.role + content = chat_msg.content + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. 
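        # Every content part is assumed to expose a .text field; refusal parts are
        # not handled yet and would need a separate branch here.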
+ contents = [TextContent(text=c.text) for c in content] + msg = Message.from_role_and_contents(role, contents) + return msg + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT + ) + return token_ids + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) + + +def parse_output_message(message: Message): + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. + return [] + + output_items = [] + recipient = message.recipient + if recipient is not None and recipient.startswith("browser."): + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + browser_call = json.loads(content.text) + # TODO: translate to url properly! + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search" + ) + elif recipient == "browser.open": + action = ActionOpenPage( + url=f"cursor:{browser_call.get('url', '')}", type="open_page" + ) + elif recipient == "browser.find": + action = ActionFind( + pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find", + ) + else: + raise ValueError(f"Unknown browser action: {recipient}") + web_search_item = ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + output_items.append(web_search_item) + elif message.channel == "analysis": + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + type="reasoning", + summary=[], + content=[ + ResponseReasoningTextContent( + text=content.text, type="reasoning_text" + ) + ], + status=None, + ) + output_items.append(reasoning_item) + elif message.channel == "commentary": + if message.recipient.startswith("functions."): + function_name = message.recipient.split(".")[-1] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"ft_{random_id}", + ) + output_items.append(response_item) + elif message.recipient.startswith("python") or message.recipient.startswith( + "browser" + ): + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + type="reasoning", + summary=[], + content=[ + ResponseReasoningTextContent( + text=content.text, type="reasoning_text" + ) + ], + status=None, + ) + output_items.append(reasoning_item) + else: + raise ValueError(f"Unknown recipient: {message.recipient}") + elif message.channel == "final": + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + status="completed", + type="message", + ) + output_items.append(text_item) + 
else: + raise ValueError(f"Unknown channel: {message.channel}") + return output_items + + +def parse_remaining_state(parser: StreamableParser): + if not parser.current_content: + return [] + if parser.current_role != Role.ASSISTANT: + return [] + current_recipient = parser.current_recipient + if current_recipient is not None and current_recipient.startswith("browser."): + return [] + + if parser.current_channel == "analysis": + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + type="reasoning", + summary=[], + content=[ + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) + ], + status=None, + ) + return [reasoning_item] + elif parser.current_channel == "final": + output_text = ResponseOutputText( + content=[ + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) + ], + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + return [text_item] + return [] + + +def parse_output_into_messages(token_ids: Iterable[int]): + parser = get_streamable_parser_for_assistant() + for token_id in token_ids: + parser.process(token_id) + return parser diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 180d33820..c4d36088f 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -32,6 +32,7 @@ from typing import AsyncIterator, Callable, Dict, Optional setattr(threading, "_register_atexit", lambda *args, **kwargs: None) from contextlib import asynccontextmanager +from typing import AsyncGenerator import numpy as np import orjson @@ -56,6 +57,7 @@ from sglang.srt.entrypoints.openai.protocol import ( ErrorResponse, ModelCard, ModelList, + ResponsesRequest, ScoringRequest, V1RerankReqInput, ) @@ -147,6 +149,37 @@ async def lifespan(fast_api_app: FastAPI): ) server_args: ServerArgs = fast_api_app.server_args + + tool_server = None + if server_args.tool_server == "demo": + from sglang.srt.entrypoints.openai.tool_server import DemoToolServer + + tool_server = DemoToolServer() + elif server_args.tool_server: + from sglang.srt.entrypoints.openai.tool_server import MCPToolServer + + tool_server = MCPToolServer() + await tool_server.add_tool_server(server_args.tool_server) + + try: + from sglang.srt.entrypoints.openai.serving_responses import ( + OpenAIServingResponses, + ) + + fast_api_app.state.openai_serving_responses = OpenAIServingResponses( + _global_state.tokenizer_manager, + _global_state.template_manager, + enable_prompt_tokens_details=True, + enable_force_include_usage=True, + tool_server=tool_server, + ) + except Exception as e: + # print stack trace + import traceback + + traceback.print_exc() + logger.warning(f"Can not initialize OpenAIServingResponses, error: {e}") + if server_args.warmups is not None: await execute_warmups( server_args.disaggregation_mode, @@ -843,6 +876,42 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request): ) +@app.post("/v1/responses", dependencies=[Depends(validate_json_request)]) +async def v1_responses_request(request: dict, raw_request: Request): + """Endpoint for the responses API with reasoning support.""" + + request_obj = ResponsesRequest(**request) + result = await raw_request.app.state.openai_serving_responses.create_responses( + request_obj, raw_request + ) + + # Handle 
streaming responses + if isinstance(result, AsyncGenerator): + return StreamingResponse( + result, + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + return result + + +@app.get("/v1/responses/{response_id}") +async def v1_retrieve_responses(response_id: str, raw_request: Request): + """Retrieve a response by ID.""" + return await raw_request.app.state.openai_serving_responses.retrieve_responses( + response_id + ) + + +@app.post("/v1/responses/{response_id}/cancel") +async def v1_cancel_responses(response_id: str, raw_request: Request): + """Cancel a background response.""" + return await raw_request.app.state.openai_serving_responses.cancel_responses( + response_id + ) + + @app.api_route( "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)] ) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index f7596c975..fb12eee1c 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -14,9 +14,18 @@ """Pydantic models for OpenAI API protocol""" import time +import uuid from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, TypeAlias, Union +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseInputItemParam, + ResponseOutputItem, + ResponseReasoningItem, +) +from openai.types.responses.response import ToolChoice +from openai.types.responses.tool import Tool from pydantic import ( BaseModel, Field, @@ -84,6 +93,7 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 # only used to return cached tokens when --enable-cache-report is set prompt_tokens_details: Optional[Dict[str, int]] = None + reasoning_tokens: Optional[int] = 0 class StreamOptions(BaseModel): @@ -428,6 +438,13 @@ class ChatCompletionRequest(BaseModel): default="auto", examples=["none"] ) # noqa return_hidden_states: bool = False + reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field( + default="medium", + description="Constrains effort on reasoning for reasoning models. " + "'low' is the least effort, 'high' is the most effort. Reducing reasoning effort can " + "result in faster responses and fewer tokens used on reasoning in a response. 
" + "Currently only supported for OpenAI models.", + ) @model_validator(mode="before") @classmethod @@ -619,6 +636,196 @@ OpenAIServingRequest = Union[ ] +# Response API protocol definitions +class ResponseReasoningParam(BaseModel): + """Reasoning parameters for responses.""" + + effort: Optional[Literal["low", "medium", "high"]] = Field( + default="medium", + description="Constrains effort on reasoning for reasoning models.", + ) + + +class ResponseTool(BaseModel): + """Tool definition for responses.""" + + type: Literal["web_search_preview", "code_interpreter"] = Field( + description="Type of tool to enable" + ) + + +ResponseInputOutputItem: TypeAlias = Union[ + ResponseInputItemParam, + "ResponseReasoningItem", + ResponseFunctionToolCall, +] + + +class ResponsesRequest(BaseModel): + """Request body for v1/responses endpoint.""" + + # Core OpenAI API fields (ordered by official documentation) + background: Optional[bool] = False + include: Optional[ + List[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ] + ] + ] = None + input: Union[str, List[ResponseInputOutputItem]] + instructions: Optional[str] = None + max_output_tokens: Optional[int] = None + max_tool_calls: Optional[int] = None + metadata: Optional[Dict[str, Any]] = None + model: Optional[str] = None # Made optional to match vLLM + parallel_tool_calls: Optional[bool] = True + previous_response_id: Optional[str] = None + reasoning: Optional[ResponseReasoningParam] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" + store: Optional[bool] = True + stream: Optional[bool] = False + temperature: Optional[float] = None + tool_choice: Literal["auto", "required", "none"] = "auto" + tools: List[ResponseTool] = Field(default_factory=list) + top_logprobs: Optional[int] = 0 + top_p: Optional[float] = None + truncation: Optional[Literal["auto", "disabled"]] = "disabled" + user: Optional[str] = None + + # Extra SGLang parameters + request_id: str = Field( + default_factory=lambda: f"resp_{uuid.uuid4().hex}", + description="The request_id related to this request. 
If the caller does not set it, a random uuid will be generated.", + ) + priority: int = Field(default=0, description="Request priority") + + # SGLang-specific sampling parameters + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + stop: Optional[Union[str, List[str]]] = None + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + + # Default sampling parameters + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 0.7, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + "repetition_penalty": 1.0, + } + + def to_sampling_params( + self, default_max_tokens: int, default_params: Optional[Dict] = None + ) -> Dict[str, Any]: + """Convert to sampling parameters for generation.""" + if default_params is None: + default_params = {} + + # Use max_output_tokens if available, otherwise use max_tokens for backwards compatibility + if self.max_output_tokens is not None: + max_tokens = min(self.max_output_tokens, default_max_tokens) + else: + max_tokens = default_max_tokens + + # Avoid exceed the context length by minus 1 token + max_tokens -= 1 + + # Get parameters with defaults + temperature = self.temperature + if temperature is None: + temperature = default_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) + + top_p = self.top_p + if top_p is None: + top_p = default_params.get("top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + + params = { + "max_new_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "stop": self.stop, + "top_k": self.top_k, + "min_p": self.min_p, + "repetition_penalty": self.repetition_penalty, + } + + # Apply any additional default parameters + for key, value in default_params.items(): + if key not in params or params[key] is None: + params[key] = value + + return params + + +class PromptTokenUsageInfo(BaseModel): + """Prompt token usage details.""" + + cached_tokens: int = 0 + + +class ResponsesResponse(BaseModel): + """Response body for v1/responses endpoint.""" + + id: str = Field(default_factory=lambda: f"resp_{time.time()}") + object: Literal["response"] = "response" + created_at: int = Field(default_factory=lambda: int(time.time())) + model: str + + output: List[ + Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall] + ] = Field(default_factory=list) + status: Literal["queued", "in_progress", "completed", "failed", "cancelled"] + usage: Optional[UsageInfo] = None + parallel_tool_calls: bool = True + tool_choice: str = "auto" + tools: List[ResponseTool] = Field(default_factory=list) + + @classmethod + def from_request( + cls, + request: ResponsesRequest, + sampling_params: Any, + model_name: str, + created_time: int, + output: List[ + Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall] + ], + status: str, + usage: Optional[UsageInfo], + ) -> "ResponsesResponse": + """Create a response from a request.""" + return cls( + id=request.request_id, + created_at=created_time, + model=model_name, + output=output, + status=status, + usage=usage, + parallel_tool_calls=request.parallel_tool_calls or True, + tool_choice=request.tool_choice, + tools=request.tools, + ) + + +class RequestResponseMetadata(BaseModel): + """Metadata for request/response tracking.""" + + request_id: str + final_usage_info: Optional[UsageInfo] = None + + @dataclass class MessageProcessingResult: """Result of processing chat messages and applying templates. 
@@ -645,3 +852,22 @@ class MessageProcessingResult: modalities: List[str] stop: List[str] tool_call_constraint: Optional[Any] = None + + +class ResponseReasoningTextContent(BaseModel): + text: str + type: Literal["reasoning_text"] = "reasoning_text" + + +class ResponseReasoningItem(BaseModel): + id: str + content: list[ResponseReasoningTextContent] = Field(default_factory=list) + summary: list = Field(default_factory=list) + type: Literal["reasoning"] = "reasoning" + encrypted_content: Optional[str] = None + status: Optional[Literal["in_progress", "completed", "incomplete"]] + + +ResponseInputOutputItem: TypeAlias = Union[ + ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall +] diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index a7beccf93..c8918ed4c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -7,8 +7,18 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse +from openai_harmony import Message as OpenAIMessage from sglang.srt.conversation import generate_chat_conv +from sglang.srt.entrypoints.harmony_utils import ( + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, + get_system_message, + parse_chat_input, + parse_output_into_messages, + render_for_completion, +) from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -51,6 +61,26 @@ class OpenAIServingChat(OpenAIServingBase): ): super().__init__(tokenizer_manager) self.template_manager = template_manager + self.use_harmony = ( + self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" + ) + + if self.use_harmony: + from sglang.srt.function_call.harmony_tool_parser import ( + HarmonyToolCallParser, + ) + + self.harmony_tool_parser = HarmonyToolCallParser() + + # NOTE While OpenAI's chat completion API supports browsing + # for some models, currently vLLM doesn't support it. Please use the + # Responses API instead. + self.supports_browsing = False + self.browser_tool = None + # NOTE: Chat completion API does not support code interpreter. + # Please use the Responses API instead. 
+ self.supports_code_interpreter = False + self.python_tool = None def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -77,41 +107,66 @@ class OpenAIServingChat(OpenAIServingBase): is_multimodal = self.tokenizer_manager.model_config.is_multimodal # Process messages and apply chat template - processed_messages = self._process_messages(request, is_multimodal) + if not self.use_harmony: + processed_messages = self._process_messages(request, is_multimodal) - # Build sampling parameters - sampling_params = self._build_sampling_params( - request, processed_messages.stop, processed_messages.tool_call_constraint - ) + # Build sampling parameters + sampling_params = self._build_sampling_params( + request, + processed_messages.stop, + processed_messages.tool_call_constraint, + ) - # Handle single vs multiple requests - if is_multimodal: - prompt_kwargs = {"text": processed_messages.prompt} - else: - if isinstance(processed_messages.prompt_ids, str): - prompt_kwargs = {"text": processed_messages.prompt_ids} + # Handle single vs multiple requests + if is_multimodal: + prompt_kwargs = {"text": processed_messages.prompt} else: - prompt_kwargs = {"input_ids": processed_messages.prompt_ids} + if isinstance(processed_messages.prompt_ids, str): + prompt_kwargs = {"text": processed_messages.prompt_ids} + else: + prompt_kwargs = {"input_ids": processed_messages.prompt_ids} - adapted_request = GenerateReqInput( - **prompt_kwargs, - image_data=processed_messages.image_data, - video_data=processed_messages.video_data, - audio_data=processed_messages.audio_data, - sampling_params=sampling_params, - return_logprob=request.logprobs, - logprob_start_len=-1, - top_logprobs_num=request.top_logprobs or 0, - stream=request.stream, - return_text_in_logprobs=True, - modalities=processed_messages.modalities, - lora_path=request.lora_path, - bootstrap_host=request.bootstrap_host, - bootstrap_port=request.bootstrap_port, - bootstrap_room=request.bootstrap_room, - return_hidden_states=request.return_hidden_states, - rid=request.rid, - ) + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=processed_messages.image_data, + video_data=processed_messages.video_data, + audio_data=processed_messages.audio_data, + sampling_params=sampling_params, + return_logprob=request.logprobs, + logprob_start_len=-1, + top_logprobs_num=request.top_logprobs or 0, + stream=request.stream, + return_text_in_logprobs=True, + modalities=processed_messages.modalities, + lora_path=request.lora_path, + bootstrap_host=request.bootstrap_host, + bootstrap_port=request.bootstrap_port, + bootstrap_room=request.bootstrap_room, + return_hidden_states=request.return_hidden_states, + rid=request.rid, + ) + else: + processed_messages, prompt_ids = self._make_request_with_harmony(request) + + adapted_request = GenerateReqInput( + input_ids=prompt_ids, + sampling_params=self._build_sampling_params( + request, + request.stop, + tool_call_constraint=None, + ), + stream=request.stream, + return_logprob=request.logprobs, + logprob_start_len=-1, + top_logprobs_num=request.top_logprobs or 0, + return_text_in_logprobs=True, + lora_path=request.lora_path, + bootstrap_host=request.bootstrap_host, + bootstrap_port=request.bootstrap_port, + bootstrap_room=request.bootstrap_room, + return_hidden_states=request.return_hidden_states, + rid=request.rid, + ) return adapted_request, request @@ -402,6 +457,12 @@ class OpenAIServingChat(OpenAIServingBase): cached_tokens = {} hidden_states = {} + # Harmony tracking + if self.use_harmony: + harmony_parsers = [ + 
get_streamable_parser_for_assistant() for _ in range(request.n) + ] + try: async for content in self.tokenizer_manager.generate_request( adapted_request, raw_request @@ -449,14 +510,57 @@ class OpenAIServingChat(OpenAIServingBase): yield f"data: {chunk.model_dump_json()}\n\n" # Process content delta - stream_buffer = stream_buffers.get(index, "") - delta = content["text"][len(stream_buffer) :] - stream_buffers[index] = stream_buffer + delta + if self.use_harmony: + harmony_parser = harmony_parsers[index] + + new_token_ids = content["output_ids"] + for token_id in new_token_ids: + harmony_parser.process(token_id) + + is_final = harmony_parser.current_channel == "final" + is_analysis = harmony_parser.current_channel == "analysis" + delta = harmony_parser.last_content_delta or "" + + if is_analysis: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(reasoning_content=delta), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + continue + + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta if delta else None), + finish_reason=None, + matched_stop=None, + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + continue + else: + stream_buffer = stream_buffers.get(index, "") + delta = content["text"][len(stream_buffer) :] + stream_buffers[index] = stream_buffer + delta # Handle reasoning content if ( self.tokenizer_manager.server_args.reasoning_parser and request.separate_reasoning + and not self.use_harmony ): reasoning_text, delta = self._process_reasoning_stream( index, delta, reasoning_parser_dict, content, request @@ -475,8 +579,27 @@ class OpenAIServingChat(OpenAIServingBase): ) yield f"data: {chunk.model_dump_json()}\n\n" + if self.use_harmony and not is_final: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(reasoning_content=delta), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + # Handle tool calls - if request.tool_choice != "none" and request.tools: + # TODO: support tool call parsing for harmony + if ( + request.tool_choice != "none" + and request.tools + and not self.use_harmony + ): async for chunk in self._process_tool_call_stream( index, delta, @@ -502,7 +625,7 @@ class OpenAIServingChat(OpenAIServingBase): if delta: choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=delta if delta else None), + delta=DeltaMessage(content=delta), finish_reason=None, matched_stop=None, logprobs=choice_logprobs, @@ -640,6 +763,76 @@ class OpenAIServingChat(OpenAIServingBase): finish_reason = ret_item["meta_info"]["finish_reason"] text = ret_item["text"] + output_ids = ret_item["output_ids"] + + if self.use_harmony: + parser = parse_output_into_messages(output_ids) + output_msgs = parser.messages + if len(output_msgs) == 0: + # The generation has stopped during reasoning. 
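                    # (No complete Harmony message was closed, so whatever the
                    # parser has buffered is treated as reasoning text; there is no
                    # final content and no tool call.)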
+ is_tool_call = False + reasoning_content = parser.current_content + final_content = None + elif len(output_msgs) == 1: + # The generation has stopped during final message. + is_tool_call = False + reasoning_content = output_msgs[0].content[0].text + final_content = parser.current_content + else: + if len(output_msgs) != 2: + raise ValueError( + "Expected 2 output messages (reasoning and final), " + f"but got {len(output_msgs)}." + ) + reasoning_msg, final_msg = output_msgs + reasoning_content = reasoning_msg.content[0].text + final_content = final_msg.content[0].text + is_tool_call = final_msg.recipient is not None + + if is_tool_call: + # Extract tool call information from final message + tool_call = ( + self.harmony_tool_parser.extract_tool_calls_from_message( + final_msg + ) + ) + tool_calls = [tool_call] if tool_call else [] + + message = ChatMessage( + role="assistant", + reasoning_content=reasoning_content, + content=None, # Tool calls don't have regular content + tool_calls=tool_calls, + ) + else: + # Normal message + message = ChatMessage( + role="assistant", + reasoning_content=reasoning_content, + content=final_content, + ) + + if is_tool_call: + finish_reason_type = "tool_calls" + elif finish_reason: + finish_reason_type = ( + finish_reason["type"] if finish_reason else "stop" + ) + else: + finish_reason_type = "stop" + choice_data = ChatCompletionResponseChoice( + index=idx, + message=message, + logprobs=choice_logprobs, + finish_reason=finish_reason_type, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + choices.append(choice_data) + continue # Handle reasoning content reasoning_text = None @@ -978,3 +1171,33 @@ class OpenAIServingChat(OpenAIServingBase): return f"data: {chunk.model_dump_json()}\n\n" return None + + def _make_request_with_harmony( + self, + request: ChatCompletionRequest, + ): + messages: list[OpenAIMessage] = [] + + # Add system message. + # In Chat Completion API, browsing is enabled by default if the model + # supports it. + assert not self.supports_browsing + assert not self.supports_code_interpreter + sys_msg = get_system_message( + reasoning_effort=request.reasoning_effort, + browser_description=None, + python_description=None, + ) + messages.append(sys_msg) + + # Add developer message. + dev_msg = get_developer_message() + messages.append(dev_msg) + + # Add user message. + for chat_msg in request.messages: + messages.append(parse_chat_input(chat_msg)) + + # Render prompt token ids. 
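        # render_for_completion() serializes the system/developer/user messages with
        # the Harmony encoding and returns the token ids the caller passes to the
        # engine as GenerateReqInput.input_ids (see the harmony branch above).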
+ prompt_token_ids = render_for_completion(messages) + return messages, prompt_token_ids diff --git a/python/sglang/srt/entrypoints/openai/serving_responses.py b/python/sglang/srt/entrypoints/openai/serving_responses.py new file mode 100644 index 000000000..a9efe4f3b --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_responses.py @@ -0,0 +1,1273 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vLLM's OpenAIServingResponses +"""Handler for /v1/responses requests""" + +import asyncio +import copy +import json +import logging +import time +from contextlib import AsyncExitStack +from http import HTTPStatus +from typing import Any, AsyncGenerator, AsyncIterator, Optional, Union + +import jinja2 +import openai.types.responses as openai_responses_types +from fastapi import Request +from fastapi.responses import ORJSONResponse +from openai.types.responses import ( + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent, +) +from openai_harmony import Message as OpenAIMessage + +from sglang.srt.entrypoints.context import ( + ConversationContext, + HarmonyContext, + SimpleContext, + StreamingHarmonyContext, +) +from sglang.srt.entrypoints.harmony_utils import ( + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_system_message, + get_user_message, + parse_output_message, + parse_remaining_state, + parse_response_input, + render_for_completion, +) +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionMessageParam, + ChatCompletionRequest, + PromptTokenUsageInfo, + RequestResponseMetadata, + ResponsesRequest, + ResponsesResponse, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat +from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.reasoning_parser import ReasoningParser +from sglang.srt.utils import random_uuid + +logger = logging.getLogger(__name__) + + +class OpenAIServingResponses(OpenAIServingChat): + """Handler for /v1/responses requests""" + + def __init__( + self, + tokenizer_manager: TokenizerManager, + template_manager: TemplateManager, + *, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + tool_server: Optional[ToolServer] = None, + ) -> None: + super().__init__(tokenizer_manager, template_manager) + + # template_manager is already set by parent class + self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_force_include_usage = enable_force_include_usage + + # Get default sampling params from model config if available + self.default_sampling_params = {} + + self.supports_browsing = ( + tool_server.has_tool("browser") if tool_server else False + ) + self.supports_code_interpreter = ( + tool_server.has_tool("python") if tool_server else False + ) + self.tool_server = tool_server + # Get from model config + self.use_harmony = ( + self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" + ) + + if self.use_harmony: + # OpenAI models have two EOS-like tokens: <|return|> and <|call|>. 
+ # We need to add them to the stop token ids. + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions() + ) + + # Response storage for background and retrieval operations + # Note: In production, this should use a proper storage backend (Redis, database) + # with TTL/expiration to prevent memory leaks + self.response_store: dict[str, ResponsesResponse] = {} + self.response_store_lock = asyncio.Lock() + + # Message storage for conversation continuity + # Note: In production, this should use a proper storage backend (Redis, database) + # with TTL/expiration to prevent memory leaks + self.msg_store: dict[ + str, Union[list[ChatCompletionMessageParam], list["OpenAIMessage"]] + ] = {} + + self.background_tasks: dict[str, asyncio.Task] = {} + + def _request_id_prefix(self) -> str: + return "resp_" + + async def create_responses( + self, + request: ResponsesRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ORJSONResponse]: + # Validate model + if not self.tokenizer_manager: + return self.create_error_response("Model not loaded") + + # FIXME: If the engine is dead, raise an error + # This is required for the streaming case + + # Handle the previous response ID + prev_response_id = request.previous_response_id + if prev_response_id is not None: + if not prev_response_id.startswith("resp_"): + return self._make_invalid_id_error(prev_response_id) + async with self.response_store_lock: + prev_response = self.response_store.get(prev_response_id) + if prev_response is None: + return self._make_not_found_error(prev_response_id) + else: + prev_response = None + + try: + model_name = request.model + tokenizer = self.tokenizer_manager.tokenizer + + if self.use_harmony: + messages, request_prompts, engine_prompts = ( + self._make_request_with_harmony(request, prev_response) + ) + else: + messages, request_prompts, engine_prompts = await self._make_request( + request, prev_response, tokenizer + ) + + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + request_metadata = RequestResponseMetadata(request_id=request.request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + if ( + self.tool_server is not None + and isinstance(self.tool_server, MCPToolServer) + and (request.background or request.stream) + and request.tools + and any( + tool.type in ["web_search_preview", "code_interpreter"] + for tool in request.tools + ) + ): + return self.create_error_response( + "MCP tool server is not supported in background mode and " + "streaming mode" + ) + + # Schedule the request and get the result generator + generators: list[AsyncGenerator[Any, None]] = [] + tool_list = [] + if self.use_harmony: + if self.supports_browsing: + tool_list.append("browser") + if self.supports_code_interpreter: + tool_list.append("python") + async with AsyncExitStack() as exit_stack: + try: + if self.tool_server is not None: + tool_session_ctxs: dict[str, Any] = { + tool_name: exit_stack.enter_async_context( + self.tool_server.get_tool_session(tool_name) + ) + for tool_name in tool_list + } + tool_sessions = {} + for tool_name in tool_list: + tool_sessions[tool_name] = await tool_session_ctxs[tool_name] + else: + assert len(tool_list) == 0 + tool_sessions = {} + 
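                # For each engine prompt: derive a default token budget from the
                # model's context length, build sampling params, choose a Harmony or
                # simple conversation context, and schedule generation with
                # built-in tool support.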
for i, engine_prompt in enumerate(engine_prompts): + # Calculate default max tokens from context length minus prompt length + if hasattr(engine_prompt, "__len__"): + prompt_length = len(engine_prompt) + elif isinstance(engine_prompt, list): + prompt_length = len(engine_prompt) + else: + prompt_length = 0 + + context_len = ( + self.tokenizer_manager.model_config.context_len + if hasattr(self.tokenizer_manager.model_config, "context_len") + else 4096 + ) + default_max_tokens = max( + context_len - prompt_length, 512 + ) # Ensure minimum 512 tokens + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params + ) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext(messages, tool_sessions) + else: + context = HarmonyContext(messages, tool_sessions) + else: + context = SimpleContext() + + # Create GenerateReqInput for SGLang + adapted_request = GenerateReqInput( + input_ids=engine_prompt, + sampling_params=sampling_params, + stream=request.stream, + rid=request.request_id, + background=request.background, + ) + + generator = self._generate_with_builtin_tools( + request.request_id, + request_prompts[i], + adapted_request, + sampling_params, + context, + raw_request=raw_request, + priority=request.priority, + ) + generators.append(generator) + except ValueError as e: + return self.create_error_response(str(e)) + + assert len(generators) == 1 + (result_generator,) = generators + + # Store the input messages + if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response + + # Run the request in the background + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) + + # For cleanup + self.background_tasks[response.id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response.id, None) + ) + return response + + if request.stream: + return self.responses_stream_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + try: + result: Union[ORJSONResponse, ResponsesResponse] = ( + await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + ) + return result + except Exception as e: + return self.create_error_response(str(e)) + return self.create_error_response("Unknown error") + + async def _make_request( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + tokenizer: Any, + ): + # Construct the input messages + messages = self._construct_input_messages(request, prev_response) + + # Follow SGLang's pattern: create a ChatCompletionRequest and process messages + try: + # Convert ResponsesRequest to ChatCompletionRequest for processing + chat_request = ChatCompletionRequest( + model=request.model, + messages=messages, + stream=request.stream, + ) + + # Follow SGLang's _process_messages pattern + is_multimodal = self.tokenizer_manager.model_config.is_multimodal + processed_messages = 
self._process_messages(chat_request, is_multimodal) + + # Extract the results + if is_multimodal: + request_prompts = [processed_messages.prompt] + engine_prompts = [processed_messages.prompt] + else: + request_prompts = [processed_messages.prompt_ids] + engine_prompts = [processed_messages.prompt_ids] + + except Exception as e: + logger.warning(f"Chat processing failed, using fallback: {e}") + # Fallback to simple encoding + prompt_text = "" + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + prompt_text += f"{role}: {content}\n" + prompt_ids = tokenizer.encode(prompt_text) + request_prompts = [prompt_ids] + engine_prompts = [prompt_ids] + + return messages, request_prompts, engine_prompts + + def _make_request_with_harmony( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + ): + if request.tool_choice != "auto": + raise NotImplementedError( + "Only 'auto' tool_choice is supported in " "response API" + ) + messages = self._construct_input_messages_with_harmony(request, prev_response) + prompt_token_ids = render_for_completion(messages) + engine_prompt = prompt_token_ids + return messages, [prompt_token_ids], [engine_prompt] + + async def responses_full_generator( + self, + request: ResponsesRequest, + sampling_params: Any, + result_generator: AsyncIterator[Any], + context: ConversationContext, + model_name: str, + tokenizer: Any, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> Union[ResponsesResponse, ORJSONResponse]: + if created_time is None: + created_time = int(time.time()) + + try: + async for _ in result_generator: + pass + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + return self.create_error_response(str(e)) + + if self.use_harmony: + assert isinstance(context, HarmonyContext) + output = self._make_response_output_items_with_harmony(context) + # TODO: these are all 0 for now! 
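            # These counters are populated by HarmonyContext.append_output from the
            # engine's meta_info (prompt/cached/completion token counts); if the
            # engine does not report them, they keep their zero defaults.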
+ num_prompt_tokens = context.num_prompt_tokens + num_generated_tokens = context.num_output_tokens + num_cached_tokens = context.num_cached_tokens + num_reasoning_tokens = context.num_reasoning_tokens + else: + assert isinstance(context, SimpleContext) + final_res = context.last_output + assert final_res is not None + + output = self._make_response_output_items( + request, final_res["text"], tokenizer + ) + + # Calculate usage from actual output + if hasattr(final_res, "meta_info"): + num_prompt_tokens = final_res.meta_info.get("prompt_tokens", 0) + num_generated_tokens = final_res.meta_info.get("completion_tokens", 0) + num_cached_tokens = final_res.meta_info.get("cached_tokens", 0) + elif hasattr(final_res, "prompt_token_ids") and hasattr( + final_res, "outputs" + ): + # Fallback calculation if meta_info not available + num_prompt_tokens = ( + len(final_res.prompt_token_ids) if final_res.prompt_token_ids else 0 + ) + num_generated_tokens = ( + len(final_res.outputs[0].token_ids) + if final_res.outputs and final_res.outputs[0].token_ids + else 0 + ) + num_cached_tokens = getattr(final_res, "num_cached_tokens", 0) + num_reasoning_tokens = 0 + else: + # Final fallback + num_prompt_tokens = 0 + num_generated_tokens = 0 + num_cached_tokens = 0 + num_reasoning_tokens = 0 + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + reasoning_tokens=num_reasoning_tokens, + ) + if self.enable_prompt_tokens_details and num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens + ) + request_metadata.final_usage_info = usage + + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=output, + status="completed", + usage=usage, + ) + + if request.store: + async with self.response_store_lock: + stored_response = self.response_store.get(response.id) + # If the response is already cancelled, don't update it + if stored_response is None or stored_response.status != "cancelled": + self.response_store[response.id] = response + + return response + + def _make_response_output_items( + self, + request: ResponsesRequest, + final_output: Any, + tokenizer: Any, + ): + # Handle reasoning parsing if enabled + if self.reasoning_parser: + # Use standard reasoning parser (openai maps to T4Detector internally) + reasoning_parser = ReasoningParser( + model_type=self.reasoning_parser, stream_reasoning=False + ) + reasoning_content, content = reasoning_parser.parse_non_stream(final_output) + else: + reasoning_content = None + content = final_output + + output_items = [] + if reasoning_content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + type="reasoning", + summary=[], + content=[ + ResponseReasoningTextContent( + type="reasoning_text", text=reasoning_content + ), + ], + status=None, + ) + output_items.append(reasoning_item) + if content: + output_text = ResponseOutputText( + text=content, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + message = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + output_items.append(message) + return output_items + + def _make_response_output_items_with_harmony( + self, + context: HarmonyContext, + ): + output_items = [] + num_init_messages = context.num_init_messages + for msg in context.messages[num_init_messages:]: 
+            output_items.extend(parse_output_message(msg))
+        # Handle the generation stopped in the middle (if any).
+        last_items = parse_remaining_state(context.parser)
+        if last_items:
+            output_items.extend(last_items)
+        return output_items
+
+    def _construct_input_messages(
+        self,
+        request: ResponsesRequest,
+        prev_response: Optional[ResponsesResponse] = None,
+    ) -> list[ChatCompletionMessageParam]:
+        messages: list[ChatCompletionMessageParam] = []
+        if request.instructions:
+            messages.append(
+                {
+                    "role": "system",
+                    "content": request.instructions,
+                }
+            )
+
+        # Prepend the conversation history
+        if prev_response is not None:
+            # Add the previous messages
+            prev_msg = self.msg_store[prev_response.id]
+            messages.extend(prev_msg)
+
+            # Add the previous output
+            for output_item in prev_response.output:
+                # NOTE: We skip the reasoning output of the previous response
+                if isinstance(output_item, ResponseReasoningItem):
+                    continue
+                for content in output_item.content:
+                    messages.append(
+                        {
+                            "role": "assistant",
+                            "content": content.text,
+                        }
+                    )
+
+        # Append the new input
+        # Responses API supports simple text inputs without chat format
+        if isinstance(request.input, str):
+            messages.append({"role": "user", "content": request.input})
+        else:
+            messages.extend(request.input)  # type: ignore
+        return messages
+
+    def _construct_input_messages_with_harmony(
+        self,
+        request: ResponsesRequest,
+        prev_response: Optional[ResponsesResponse],
+    ) -> list["OpenAIMessage"]:
+        messages: list["OpenAIMessage"] = []
+        if prev_response is None:
+            # New conversation.
+            reasoning_effort = request.reasoning.effort if request.reasoning else None
+            tool_types = [tool.type for tool in request.tools]
+            enable_browser = (
+                "web_search_preview" in tool_types and self.tool_server is not None
+            )
+            enable_code_interpreter = (
+                "code_interpreter" in tool_types and self.tool_server is not None
+            )
+            sys_msg = get_system_message(
+                reasoning_effort=reasoning_effort,
+                browser_description=(
+                    self.tool_server.get_tool_description("browser")
+                    if self.tool_server and enable_browser
+                    else None
+                ),
+                python_description=(
+                    self.tool_server.get_tool_description("python")
+                    if self.tool_server and enable_code_interpreter
+                    else None
+                ),
+            )
+            messages.append(sys_msg)
+            dev_msg = get_developer_message(request.instructions, request.tools)
+            messages.append(dev_msg)
+        else:
+            # Continue the previous conversation.
+            # FIXME: Currently, request params like reasoning and
+            # instructions are ignored.
+            prev_msgs = self.msg_store[prev_response.id]
+            # Remove the previous chain-of-thoughts if there is a new "final"
+            # message.
+            if (
+                len(prev_msgs) > 0
+                and hasattr(prev_msgs[-1], "channel")
+                and prev_msgs[-1].channel == "final"
+            ):  # type: ignore[union-attr]
+                prev_final_msg_idx = -1
+                for i in range(len(prev_msgs) - 2, -1, -1):
+                    if (
+                        hasattr(prev_msgs[i], "channel")
+                        and prev_msgs[i].channel == "final"
+                    ):  # type: ignore[union-attr]
+                        prev_final_msg_idx = i
+                        break
+                recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1 :]
+                del prev_msgs[prev_final_msg_idx + 1 :]
+                for msg in recent_turn_msgs:
+                    if (
+                        hasattr(msg, "channel") and msg.channel != "analysis"
+                    ):  # type: ignore[union-attr]
+                        prev_msgs.append(msg)
+            messages.extend(prev_msgs)
+        # Append the new input.
+        # Responses API supports simple text inputs without chat format.
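+        # For example, a plain string input such as "What is 2+2?" becomes a single
+        # harmony user message via get_user_message(), while a structured list input
+        # is converted item by item with parse_response_input() below.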
+ if isinstance(request.input, str): + messages.append(get_user_message(request.input)) + else: + if prev_response is not None: + prev_outputs = copy(prev_response.output) + else: + prev_outputs = [] + for response_msg in request.input: + messages.append(parse_response_input(response_msg, prev_outputs)) + if isinstance(response_msg, ResponseFunctionToolCall): + prev_outputs.append(response_msg) + return messages + + async def _run_background_request( + self, + request: ResponsesRequest, + sampling_params: Any, + result_generator: AsyncIterator[Any], + context: ConversationContext, + model_name: str, + tokenizer: Any, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + *args, + **kwargs, + ): + try: + # Update the status to "in_progress" + async with self.response_store_lock: + stored_response = self.response_store.get(request.request_id) + assert stored_response is not None + stored_response.status = "in_progress" + + response = await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + *args, + **kwargs, + ) + except Exception as e: + logger.exception("Background request failed for %s", request.request_id) + response = self.create_error_response(str(e)) + + if isinstance(response, ORJSONResponse): + # If the request has failed, update the status to "failed" + response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + + async def retrieve_responses( + self, + response_id: str, + ) -> Union[ResponsesResponse, ORJSONResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + + if response is None: + return self._make_not_found_error(response_id) + return response + + async def cancel_responses( + self, + response_id: str, + ) -> Union[ResponsesResponse, ORJSONResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + if response is None: + return self._make_not_found_error(response_id) + + prev_status = response.status + if prev_status not in ("queued", "in_progress"): + return self.create_error_response( + err_type="invalid_request_error", + message="Cannot cancel a synchronous response.", + ) + + # Update the status to "cancelled" + response.status = "cancelled" + + # Abort the request + if task := self.background_tasks.get(response_id): + task.cancel() + try: + await task + except asyncio.CancelledError: + logger.exception("Background task for %s was cancelled", response_id) + return response + + def _make_invalid_id_error(self, response_id: str): + return self.create_error_response( + message=( + f"Invalid 'response_id': '{response_id}'. " + "Expected an ID that begins with 'resp'." 
+ ), + err_type="invalid_request_error", + param="response_id", + ) + + def _make_not_found_error(self, response_id: str): + return self.create_error_response( + message=f"Response with id '{response_id}' not found.", + err_type="invalid_request_error", + status_code=HTTPStatus.NOT_FOUND, + param="response_id", + ) + + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: Any, + result_generator: AsyncIterator[StreamingHarmonyContext], + context: StreamingHarmonyContext, + model_name: str, + tokenizer: Any, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[str, None]: + # TODO: + # 1. Handle disconnect + + created_time = created_time or int(time.time()) + + sequence_number = 0 + + def _send_event(event): + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, "sequence_number"): + event.sequence_number = sequence_number + sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, "type", "unknown") + return ( + f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n" + ) + + current_content_index = 0 + current_output_index = 0 + current_item_id = f"item_{random_uuid()}" + sent_output_item_added = False + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _send_event( + openai_responses_types.ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + ) + ) + yield _send_event( + openai_responses_types.ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + ) + ) + + async for ctx in result_generator: + + if ctx.is_expecting_start(): + current_output_index += 1 + sent_output_item_added = False + + if len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if previous_item.recipient is not None: + # Deal with tool call here + pass + elif previous_item.channel == "analysis": + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + type="reasoning", + summary=[], + content=[ + ResponseReasoningTextContent( + text=previous_item.content[0].text, + type="reasoning_text", + ), + ], + status="completed", + ) + yield _send_event( + openai_responses_types.ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + ) + ) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + ) + ) + elif previous_item.channel == "final": + text_content = openai_responses_types.ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + yield _send_event( + openai_responses_types.ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + logprobs=[], + item_id=current_item_id, + ) + ) + yield _send_event( + openai_responses_types.ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + 
item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=text_content, + ) + ) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + ) + ) + + if ctx.parser.last_content_delta: + if ( + ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None + ): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + ) + ) + yield _send_event( + openai_responses_types.ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + yield _send_event( + openai_responses_types.ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + ) + ) + elif ( + ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None + ): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + ) + ) + yield _send_event( + openai_responses_types.ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + # TODO: migrate this to + # ResponseReasoningTextContent for now + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + # TODO: migrate to OpenAI types once updated. 
+ yield _send_event( + openai_responses_types.ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + delta=ctx.parser.last_content_delta, + sequence_number=-1, + ) + ) + + if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if ( + self.supports_browsing + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.") + ): + function_name = previous_item.recipient[len("browser.") :] + action = None + parsed_args = json.loads(previous_item.content[0].text) + if function_name == "search": + action = openai_responses_types.response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + ) + elif function_name == "open": + action = openai_responses_types.response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) + elif function_name == "find": + action = openai_responses_types.response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) + else: + raise ValueError(f"Unknown function name: {function_name}") + + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.response_function_web_search.ResponseFunctionWebSearch( + # TODO: generate a unique id for web search call + type="web_search_call", + id=current_item_id, + action=action, + status="in_progress", + ), + ) + ) + yield _send_event( + openai_responses_types.ResponseWebSearchCallInProgressEvent( + type="response.web_search_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + yield _send_event( + openai_responses_types.ResponseWebSearchCallSearchingEvent( + type="response.web_search_call.searching", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + + # enqueue + yield _send_event( + openai_responses_types.ResponseWebSearchCallCompletedEvent( + type="response.web_search_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseFunctionWebSearch( + type="web_search_call", + id=current_item_id, + action=action, + status="completed", + ), + ) + ) + + if ( + self.supports_code_interpreter + and previous_item.recipient is not None + and previous_item.recipient.startswith("python") + ): + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code="", + container_id="auto", + outputs=[], + status="in_progress", + ), + ) + ) + yield _send_event( + openai_responses_types.ResponseCodeInterpreterCallInProgressEvent( + type="response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + # TODO: do we need to 
add delta event here? + yield _send_event( + openai_responses_types.ResponseCodeInterpreterCallCodeDoneEvent( + type="response.code_interpreter_call_code.done", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + code=previous_item.content[0].text, + ) + ) + yield _send_event( + openai_responses_types.ResponseCodeInterpreterCallInterpretingEvent( + type="response.code_interpreter_call.interpreting", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + yield _send_event( + openai_responses_types.ResponseCodeInterpreterCallCompletedEvent( + type="response.code_interpreter_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=previous_item.content[0].text, + container_id="auto", + # TODO: add outputs here + outputs=[], + status="completed", + ), + ) + ) + + async def empty_async_generator(): + if False: + yield + + final_response = await self.responses_full_generator( + request, + sampling_params, + empty_async_generator(), + context, + model_name, + tokenizer, + request_metadata, + created_time=created_time, + ) + # Convert final_response to the format expected by ResponseCompletedEvent + response_dict = final_response.model_dump() + + # Convert UsageInfo to ResponseUsage format + if response_dict.get("usage"): + usage_info = response_dict["usage"] + response_dict["usage"] = { + "input_tokens": usage_info.get("prompt_tokens", 0), + "input_tokens_details": { + "cached_tokens": usage_info.get("cached_tokens", 0) + }, + "output_tokens": usage_info.get("completion_tokens", 0), + "output_tokens_details": { + "reasoning_tokens": usage_info.get("reasoning_tokens", 0) + }, + "total_tokens": usage_info.get("total_tokens", 0), + } + + yield _send_event( + openai_responses_types.ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=response_dict, + ) + ) + + async def _generate_with_builtin_tools( + self, + request_id: str, + request_prompt: Any, + adapted_request: GenerateReqInput, + sampling_params: Any, + context: ConversationContext, + raw_request: Optional[Request] = None, + priority: Optional[int] = None, + **kwargs, + ) -> AsyncGenerator[Any, None]: + """Generate with builtin tool support for harmony-based models.""" + orig_priority = priority or 0 + + while True: + # Generate using SGLang's tokenizer manager + generator = self.tokenizer_manager.generate_request( + adapted_request, raw_request + ) + + async for res in generator: + context.append_output(res) + # NOTE(woosuk): The stop condition is handled by the engine. + yield context + + if not context.need_builtin_tool_call(): + # The model did not ask for a tool call, so we're done. + break + + # Call the tool and update the context with the result. 
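+            # The tool result is a list of harmony messages; appending it to the
+            # context lets render_for_completion() below rebuild the prompt for the
+            # next turn of this loop.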
+ tool_output = await context.call_tool() + context.append_output(tool_output) + + # Prepare for the next generation turn + # Render the updated conversation for the next completion + prompt_token_ids = context.render_for_completion() + + # Update the adapted request with new prompt + adapted_request = GenerateReqInput( + input_ids=prompt_token_ids, + sampling_params=sampling_params, + stream=adapted_request.stream, + rid=request_id, + return_logprob=adapted_request.return_logprob, + logprob_start_len=adapted_request.logprob_start_len, + top_logprobs_num=adapted_request.top_logprobs_num, + return_text_in_logprobs=adapted_request.return_text_in_logprobs, + return_hidden_states=adapted_request.return_hidden_states, + background=adapted_request.background, + ) + + # Update sampling params with reduced max_tokens + if hasattr(sampling_params, "max_new_tokens") or isinstance( + sampling_params, dict + ): + context_len = getattr( + self.tokenizer_manager.model_config, "context_len", 4096 + ) + remaining_tokens = context_len - len(prompt_token_ids) - 1 + + if isinstance(sampling_params, dict): + sampling_params["max_new_tokens"] = max(remaining_tokens, 1) + else: + sampling_params.max_new_tokens = max(remaining_tokens, 1) + + # Slightly reduce priority for subsequent tool calls + priority = orig_priority - 1 diff --git a/python/sglang/srt/entrypoints/openai/tool_server.py b/python/sglang/srt/entrypoints/openai/tool_server.py new file mode 100644 index 000000000..fd66eb42b --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/tool_server.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any + +logger = logging.getLogger(__name__) +try: + from mcp import ClientSession + from mcp.client.sse import sse_client + from mcp.types import ListToolsResult +except ImportError: + logger.warning("Ignoring mcp import error") + +from openai_harmony import ToolDescription, ToolNamespaceConfig + + +async def list_server_and_tools(server_url: str): + + async with sse_client(url=server_url) as streams, ClientSession( + *streams + ) as session: + initialize_response = await session.initialize() + list_tools_response = await session.list_tools() + return initialize_response, list_tools_response + + +def trim_schema(schema: dict) -> dict: + # Turn JSON Schema from MCP generated into Harmony's variant. + if "title" in schema: + del schema["title"] + if "default" in schema and schema["default"] is None: + del schema["default"] + if "anyOf" in schema: + # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}] + # into "type": ["type-1", "type-2"] + # if there's more than 1 types, also remove "null" type as Harmony will + # just ignore it + types = [ + type_dict["type"] + for type_dict in schema["anyOf"] + if type_dict["type"] != "null" + ] + schema["type"] = types + del schema["anyOf"] + if "properties" in schema: + schema["properties"] = { + k: trim_schema(v) for k, v in schema["properties"].items() + } + return schema + + +def post_process_tools_description( + list_tools_result: "ListToolsResult", +) -> "ListToolsResult": + # Adapt the MCP tool result for Harmony + for tool in list_tools_result.tools: + tool.inputSchema = trim_schema(tool.inputSchema) + + # Some tools schema don't need to be part of the prompt (e.g. 
simple text + # in text out for Python) + list_tools_result.tools = [ + tool + for tool in list_tools_result.tools + if getattr(tool.annotations, "include_in_prompt", True) + ] + + return list_tools_result + + +class ToolServer(ABC): + + @abstractmethod + def has_tool(self, tool_name: str): + pass + + @abstractmethod + def get_tool_description(self, tool_name: str): + pass + + @abstractmethod + def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ... + + +class MCPToolServer(ToolServer): + + def __init__(self): + self.harmony_tool_descriptions = {} + + async def add_tool_server(self, server_url: str): + tool_urls = server_url.split(",") + self.harmony_tool_descriptions = {} + self.urls: dict[str, str] = {} + for url in tool_urls: + url = f"http://{url}/sse" + initialize_response, list_tools_response = await list_server_and_tools(url) + + list_tools_response = post_process_tools_description(list_tools_response) + + tool_from_mcp = ToolNamespaceConfig( + name=initialize_response.serverInfo.name, + description=initialize_response.instructions, + tools=[ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.inputSchema, + ) + for tool in list_tools_response.tools + ], + ) + self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp + if tool_from_mcp.name not in self.urls: + self.urls[tool_from_mcp.name] = url + else: + logger.warning( + "Tool %s already exists. Ignoring duplicate tool server %s", + tool_from_mcp.name, + url, + ) + + def has_tool(self, tool_name: str): + return tool_name in self.harmony_tool_descriptions + + def get_tool_description(self, tool_name: str): + return self.harmony_tool_descriptions.get(tool_name) + + @asynccontextmanager + async def get_tool_session(self, tool_name: str): + url = self.urls.get(tool_name) + if url: + async with sse_client(url=url) as streams, ClientSession( + *streams + ) as session: + await session.initialize() + yield session + else: + logger.warning("Tool %s not found", tool_name) + + +class DemoToolServer(ToolServer): + + def __init__(self): + from sglang.srt.entrypoints.tool import ( + HarmonyBrowserTool, + HarmonyPythonTool, + Tool, + ) + + self.tools: dict[str, Tool] = {} + browser_tool = HarmonyBrowserTool() + if browser_tool.enabled: + self.tools["browser"] = browser_tool + python_tool = HarmonyPythonTool() + if python_tool.enabled: + self.tools["python"] = python_tool + + def has_tool(self, tool_name: str): + return tool_name in self.tools + + def get_tool_description(self, tool_name: str): + if tool_name not in self.tools: + return None + if tool_name == "browser": + return ToolNamespaceConfig.browser() + elif tool_name == "python": + return ToolNamespaceConfig.python() + else: + raise ValueError(f"Unknown tool {tool_name}") + + @asynccontextmanager + async def get_tool_session(self, tool_name: str): + yield self.tools[tool_name] diff --git a/python/sglang/srt/entrypoints/tool.py b/python/sglang/srt/entrypoints/tool.py new file mode 100644 index 000000000..05c1c8ede --- /dev/null +++ b/python/sglang/srt/entrypoints/tool.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + # Avoid circular import. 
+ from sglang.srt.entrypoints.context import ConversationContext + +logger = logging.getLogger(__name__) + + +class Tool(ABC): + + @abstractmethod + async def get_result(self, context: "ConversationContext") -> Any: + pass + + +class HarmonyBrowserTool(Tool): + + def __init__(self): + self.enabled = True + exa_api_key = os.getenv("EXA_API_KEY") + if not exa_api_key: + self.enabled = False + logger.warning_once("EXA_API_KEY is not set, browsing is disabled") + return + + try: + from gpt_oss.tools.simple_browser import SimpleBrowserTool + from gpt_oss.tools.simple_browser.backend import ExaBackend + except ImportError: + self.enabled = False + logger.warning_once("gpt_oss is not installed, browsing is disabled") + return + + browser_backend = ExaBackend(source="web", api_key=exa_api_key) + self.browser_tool = SimpleBrowserTool(backend=browser_backend) + logger.info_once("Browser tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from sglang.srt.entrypoints.context import HarmonyContext + + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.browser_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.browser_tool.tool_config + + +class HarmonyPythonTool(Tool): + + def __init__(self): + self.enabled = True + + try: + from gpt_oss.tools.python_docker.docker_tool import PythonTool + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, code interpreter is disabled" + ) + return + + self.python_tool = PythonTool() + logger.info_once("Code interpreter tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from sglang.srt.entrypoints.context import HarmonyContext + + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.python_tool.tool_config diff --git a/python/sglang/srt/function_call/harmony_tool_parser.py b/python/sglang/srt/function_call/harmony_tool_parser.py new file mode 100644 index 000000000..10f82856b --- /dev/null +++ b/python/sglang/srt/function_call/harmony_tool_parser.py @@ -0,0 +1,130 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Harmony tool call parser for processing tool calls in harmony models.""" + +import uuid +from typing import List, Optional, Tuple + +from sglang.srt.entrypoints.openai.protocol import ( + ChatMessage, + FunctionResponse, + ToolCall, +) + + +class HarmonyToolCallParser: + """Parser for extracting tool calls from harmony model outputs.""" + + def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]: + """ + Extract tool call from a single message if it's a tool call. + + Args: + msg: The harmony message + + Returns: + ToolCall if the message is a tool call, None otherwise + """ + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): + function_name = msg.recipient.split(".")[-1] + arguments = msg.content[0].text if msg.content else "{}" + + return ToolCall( + id=f"call_{uuid.uuid4().hex[:24]}", + function=FunctionResponse( + name=function_name, + arguments=arguments, + ), + ) + return None + + def process_streaming_chunk( + self, + harmony_parser, + index: int, + tool_call_trackers: dict, + stream_buffers: dict, + ) -> Tuple[Optional[dict], bool, Optional[str]]: + """ + Process a streaming chunk for tool calls. + + Args: + harmony_parser: The harmony parser instance + index: The choice index + tool_call_trackers: Dict tracking tool calls per choice + stream_buffers: Dict for buffering content + + Returns: + Tuple of (tool_call_data, is_tool_call, delta) + """ + # Check if we're in a tool call + is_tool_call = ( + harmony_parser.current_channel == "commentary" + and harmony_parser.current_recipient + and harmony_parser.current_recipient.startswith("functions.") + ) + + delta = harmony_parser.last_content_delta or "" + tool_call_data = None + + if is_tool_call: + # Handle tool call streaming + function_name = harmony_parser.current_recipient.split(".")[-1] + + # Track tool call indices per choice + if index not in tool_call_trackers: + tool_call_trackers[index] = {"count": 0, "current_function": None} + + # Check if we just started a new tool call + tool_call_tracker = tool_call_trackers[index] + if tool_call_tracker["current_function"] != function_name: + # New tool call started + tool_call_tracker["current_function"] = function_name + tool_call_index = tool_call_tracker["count"] + tool_call_tracker["count"] += 1 + + # Store the tool call index for this function + tool_call_key = f"{index}_{function_name}" + stream_buffers[tool_call_key] = { + "index": tool_call_index, + "content": "", + } + + tool_call_data = { + "id": f"call_{uuid.uuid4().hex[:24]}", + "index": tool_call_index, + "function_name": function_name, + "arguments": delta, + "is_first_chunk": True, + } + else: + # Subsequent chunks for the same tool call + tool_call_key = f"{index}_{function_name}" + tool_call_index = stream_buffers[tool_call_key]["index"] + + tool_call_data = { + "id": None, + "index": tool_call_index, + "function_name": None, + "arguments": delta, + "is_first_chunk": False, + } + + stream_buffers[tool_call_key]["content"] += delta + + return tool_call_data, is_tool_call, delta diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 811f108c7..29757b4b2 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -216,7 +216,7 @@ class DetokenizerManager: rids=recv_obj.rids, finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, 
- output_ids=None, + output_ids=recv_obj.decode_ids, prompt_tokens=recv_obj.prompt_tokens, completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c1c5f0735..1a0cbeadb 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -126,6 +126,9 @@ class GenerateReqInput: # For data parallel rank routing data_parallel_rank: Optional[int] = None + # For background responses (OpenAI responses API) + background: bool = False + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -560,6 +563,9 @@ class EmbeddingReqInput: # For cross-encoder requests is_cross_encoder_request: bool = False + # For background responses (OpenAI responses API) + background: bool = False + def normalize_batch_and_arguments(self): # at least one of text, input_ids, or image should be provided if self.text is None and self.input_ids is None and self.image_data is None: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 635121920..a86899f6e 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin: req.send_decode_id_offset = len(decode_ids) read_offsets.append(read_offset) - if self.skip_tokenizer_init: - output_ids.append(req.output_ids[send_token_offset:]) + output_ids.append(req.output_ids[send_token_offset:]) req.send_token_offset = len(req.output_ids) skip_special_tokens.append(req.sampling_params.skip_special_tokens) spaces_between_special_tokens.append( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 63cbfd59e..498f0daef 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -750,7 +750,11 @@ class TokenizerManager: try: await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: - if request is not None and await request.is_disconnected(): + if ( + request is not None + and not obj.background + and await request.is_disconnected() + ): # Abort the request for disconnected requests (non-streaming, waiting queue) self.abort_request(obj.rid) # Use exception to kill the whole call stack and asyncio task @@ -805,7 +809,11 @@ class TokenizerManager: if obj.stream: yield out else: - if request is not None and await request.is_disconnected(): + if ( + request is not None + and not obj.background + and await request.is_disconnected() + ): # Abort the request for disconnected requests (non-streaming, running) self.abort_request(obj.rid) # Use exception to kill the whole call stack and asyncio task @@ -1548,8 +1556,17 @@ class TokenizerManager: if isinstance(recv_obj, BatchStrOut): state.text += recv_obj.output_strs[i] + if state.obj.stream: + state.output_ids.extend(recv_obj.output_ids[i]) + output_token_ids = state.output_ids[state.last_output_offset :] + state.last_output_offset = len(state.output_ids) + else: + state.output_ids.extend(recv_obj.output_ids[i]) + output_token_ids = state.output_ids.copy() + out_dict = { "text": state.text, + "output_ids": output_token_ids, "meta_info": meta_info, } elif isinstance(recv_obj, BatchTokenIDOut): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 
69c840a7b..2623a1027 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -274,6 +274,9 @@ class ServerArgs: enable_pdmux: bool = False sm_group_num: int = 3 + # For tool server + tool_server: Optional[str] = None + # Deprecated arguments enable_ep_moe: bool = False enable_deepep_moe: bool = False @@ -1916,6 +1919,14 @@ class ServerArgs: help="Disable mmap while loading weight using safetensors.", ) + # For tool server + parser.add_argument( + "--tool-server", + type=str, + default=None, + help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.", + ) + # Deprecated arguments parser.add_argument( "--enable-ep-moe", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2eb0d28b2..1e07a4136 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -41,6 +41,7 @@ import tempfile import threading import time import traceback +import uuid import warnings from collections import OrderedDict, defaultdict from contextlib import contextmanager @@ -233,6 +234,10 @@ def is_flashinfer_available(): return importlib.util.find_spec("flashinfer") is not None and is_cuda() +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var( "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false" )
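
Illustrative usage sketch (not part of the patch): with a server launched via
`python -m sglang.launch_server --model-path <model> --tool-server demo`, the new
/v1/responses endpoint can be exercised through the standard OpenAI client. The
base URL, port, and model name below are placeholders that depend on how the
server is started.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

    # Single-turn request against the new Responses endpoint.
    response = client.responses.create(
        model="default",
        input="Summarize the harmony message format in one sentence.",
    )
    print(response.output_text)

    # Follow-up turn chained onto the stored response (assumes the first
    # request was stored, i.e. store=true).
    follow_up = client.responses.create(
        model="default",
        previous_response_id=response.id,
        input="Now expand that to two sentences.",
    )
    print(follow_up.output_text)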