Add minimal vLLM 0.16.1 build repo for BI-V150

vllm/entrypoints/openai/responses/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

vllm/entrypoints/openai/responses/api_router.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import AsyncGenerator
from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.responses.protocol import (
    ResponsesRequest,
    ResponsesResponse,
    StreamingResponsesResponse,
)
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def responses(request: Request) -> OpenAIServingResponses | None:
    return request.app.state.openai_serving_responses


async def _convert_stream_to_sse_events(
    generator: AsyncGenerator[StreamingResponsesResponse, None],
) -> AsyncGenerator[str, None]:
    """Convert the generator to a stream of events in SSE format."""
    async for event in generator:
        event_type = getattr(event, "type", "unknown")
        # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
        event_data = (
            f"event: {event_type}\ndata: {event.model_dump_json(indent=None)}\n\n"
        )
        yield event_data


@router.post(
    "/v1/responses",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_responses(request: ResponsesRequest, raw_request: Request):
    handler = responses(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Responses API"
        )
    try:
        generator = await handler.create_responses(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ResponsesResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(
        content=_convert_stream_to_sse_events(generator), media_type="text/event-stream"
    )


@router.get("/v1/responses/{response_id}")
@load_aware_call
async def retrieve_responses(
    response_id: str,
    raw_request: Request,
    starting_after: int | None = None,
    stream: bool | None = False,
):
    handler = responses(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Responses API"
        )

    try:
        response = await handler.retrieve_responses(
            response_id,
            starting_after=starting_after,
            stream=stream,
        )
    except Exception as e:
        response = handler.create_error_response(e)

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )
    elif isinstance(response, ResponsesResponse):
        return JSONResponse(content=response.model_dump())
    return StreamingResponse(
        content=_convert_stream_to_sse_events(response), media_type="text/event-stream"
    )


@router.post("/v1/responses/{response_id}/cancel")
@load_aware_call
async def cancel_responses(response_id: str, raw_request: Request):
    handler = responses(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Responses API"
        )

    try:
        response = await handler.cancel_responses(response_id)
    except Exception as e:
        response = handler.create_error_response(e)

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )
    return JSONResponse(content=response.model_dump())


def attach_router(app: FastAPI):
    app.include_router(router)
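
A minimal client-side sketch (not part of this diff) of consuming the SSE framing that _convert_stream_to_sse_events produces. It assumes a running vLLM server on localhost:8000 and the httpx package; the model name in the payload is a placeholder.

import json

import httpx

payload = {
    "model": "my-model",  # placeholder; use a model the server actually serves
    "input": "Say hello.",
    "stream": True,
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/responses", json=payload, timeout=None
) as resp:
    event_type = None
    for line in resp.iter_lines():
        # Each frame is "event: <type>\ndata: <json>\n\n"; a blank line
        # terminates the frame, matching the MDN event-stream format.
        if line.startswith("event: "):
            event_type = line[len("event: "):]
        elif line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event_type, event.get("type"))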

vllm/entrypoints/openai/responses/context.py (new file, 918 lines)
@@ -0,0 +1,918 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import copy
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Final, Union

from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm import envs
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.mcp.tool import Tool
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.engine.protocol import (
    FunctionCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_encoding,
    get_streamable_parser_for_assistant,
    render_for_completion,
)
from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponseRawMessageAndToken,
    ResponsesRequest,
)
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid

if TYPE_CHECKING:
    from mcp.client import ClientSession

logger = logging.getLogger(__name__)

# This is currently needed because the tool type doesn't match the tool
# namespace 1:1; the namespace is what is used to look up the connection
# to the tool server.
_TOOL_NAME_TO_TYPE_MAP = {
    "browser": "web_search_preview",
    "python": "code_interpreter",
    "container": "container",
}


def _map_tool_name_to_tool_type(tool_name: str) -> str:
    if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
        available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
        raise ValueError(
            f"Built-in tool name '{tool_name}' not defined in mapping. "
            f"Available tools: {available_tools}"
        )
    return _TOOL_NAME_TO_TYPE_MAP[tool_name]


class TurnMetrics:
    """Tracks token and tool-call details for a single conversation turn."""

    def __init__(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cached_input_tokens: int = 0,
        tool_output_tokens: int = 0,
    ) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.cached_input_tokens = cached_input_tokens
        self.tool_output_tokens = tool_output_tokens

    def reset(self) -> None:
        """Reset counters for a new turn."""
        self.input_tokens = 0
        self.output_tokens = 0
        self.cached_input_tokens = 0
        self.tool_output_tokens = 0

    def copy(self) -> "TurnMetrics":
        """Create a copy of this turn's token counts."""
        return TurnMetrics(
            self.input_tokens,
            self.output_tokens,
            self.cached_input_tokens,
            self.tool_output_tokens,
        )


class ConversationContext(ABC):
    @abstractmethod
    def append_output(self, output: RequestOutput) -> None:
        pass

    @abstractmethod
    def append_tool_output(self, output) -> None:
        pass

    @abstractmethod
    async def call_tool(self) -> list[Message]:
        pass

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        pass

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        pass

    @abstractmethod
    async def init_tool_sessions(
        self,
        tool_server: ToolServer | None,
        exit_stack: AsyncExitStack,
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ) -> None:
        pass

    @abstractmethod
    async def cleanup_session(self) -> None:
        raise NotImplementedError("Should not be called.")


def _create_json_parse_error_messages(
    last_msg: Message, e: json.JSONDecodeError
) -> list[Message]:
    """
    Create an error message when JSON parsing of tool arguments fails.
    """
    error_msg = (
        f"Error parsing tool arguments as JSON: {str(e)}. "
        "Please ensure the tool call arguments are valid JSON and try again."
    )
    content = TextContent(text=error_msg)
    author = Author(role=Role.TOOL, name=last_msg.recipient)
    return [
        Message(
            author=author,
            content=[content],
            recipient=Role.ASSISTANT,
            channel=last_msg.channel,
        )
    ]


class SimpleContext(ConversationContext):
    """This is a context that cannot handle MCP tool calls."""

    def __init__(self):
        self.last_output = None

        # Accumulated final output for streaming mode
        self._accumulated_text: str = ""
        self._accumulated_token_ids: list[int] = []
        self._accumulated_logprobs: list = []

        self.num_prompt_tokens = 0
        self.num_output_tokens = 0
        self.num_cached_tokens = 0
        # TODO: num_reasoning_tokens is not implemented yet.
        self.num_reasoning_tokens = 0
        # Turn tracking is not implemented yet for SimpleContext.
        self.all_turn_metrics = []

        self.input_messages: list[ResponseRawMessageAndToken] = []

    def append_output(self, output) -> None:
        self.last_output = output
        if not isinstance(output, RequestOutput):
            raise ValueError("SimpleContext only supports RequestOutput.")
        self.num_prompt_tokens = len(output.prompt_token_ids or [])
        self.num_cached_tokens = output.num_cached_tokens or 0
        self.num_output_tokens += len(output.outputs[0].token_ids or [])

        # Accumulate text, token_ids, and logprobs for streaming mode
        delta_output = output.outputs[0]
        self._accumulated_text += delta_output.text
        self._accumulated_token_ids.extend(delta_output.token_ids)
        if delta_output.logprobs is not None:
            self._accumulated_logprobs.extend(delta_output.logprobs)

        if len(self.input_messages) == 0:
            output_prompt = output.prompt or ""
            output_prompt_token_ids = output.prompt_token_ids or []
            self.input_messages.append(
                ResponseRawMessageAndToken(
                    message=output_prompt,
                    tokens=output_prompt_token_ids,
                )
            )

    @property
    def output_messages(self) -> list[ResponseRawMessageAndToken]:
        """Return consolidated output as a single message.

        In streaming mode, text and tokens are accumulated across many deltas.
        This property returns them as a single entry rather than one per delta.
        """
        if not self._accumulated_text and not self._accumulated_token_ids:
            return []
        return [
            ResponseRawMessageAndToken(
                message=self._accumulated_text,
                tokens=list(self._accumulated_token_ids),
            )
        ]

    @property
    def final_output(self) -> RequestOutput | None:
        """Return the final output, with complete text/token_ids/logprobs."""
        if self.last_output is not None and self.last_output.outputs:
            assert isinstance(self.last_output, RequestOutput)
            final_output = copy.copy(self.last_output)
            # Copy inner items to avoid modifying last_output.
            final_output.outputs = [replace(item) for item in self.last_output.outputs]
            final_output.outputs[0].text = self._accumulated_text
            final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
            if self._accumulated_logprobs:
                final_output.outputs[0].logprobs = self._accumulated_logprobs
            return final_output
        return self.last_output

    def append_tool_output(self, output) -> None:
        raise NotImplementedError("Should not be called.")

    def need_builtin_tool_call(self) -> bool:
        return False

    async def call_tool(self) -> list[Message]:
        raise NotImplementedError("Should not be called.")

    def render_for_completion(self) -> list[int]:
        raise NotImplementedError("Should not be called.")

    async def init_tool_sessions(
        self,
        tool_server: ToolServer | None,
        exit_stack: AsyncExitStack,
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ) -> None:
        pass

    async def cleanup_session(self) -> None:
        raise NotImplementedError("Should not be called.")


class ParsableContext(ConversationContext):
    def __init__(
        self,
        *,
        response_messages: list[ResponseInputOutputItem],
        tokenizer: TokenizerLike,
        reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
        request: ResponsesRequest,
        available_tools: list[str] | None,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0
        self.num_cached_tokens = 0
        self.num_reasoning_tokens = 0
        # Turn tracking is not implemented yet for ParsableContext.
        self.all_turn_metrics: list[TurnMetrics] = []

        if reasoning_parser_cls is None:
            raise ValueError("reasoning_parser_cls must be provided.")

        self.parser = get_responses_parser_for_simple_context(
            tokenizer=tokenizer,
            reasoning_parser_cls=reasoning_parser_cls,
            response_messages=response_messages,
            request=request,
            tool_parser_cls=tool_parser_cls,
        )
        self.tool_parser_cls = tool_parser_cls
        self.request = request

        self.available_tools = available_tools or []
        self._tool_sessions: dict[str, Union["ClientSession", Tool]] = {}
        self.called_tools: set[str] = set()

        self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        self.input_messages: list[ResponseRawMessageAndToken] = []
        self.output_messages: list[ResponseRawMessageAndToken] = []
        self._accumulated_token_ids: list[int] = []

    def append_output(self, output: RequestOutput) -> None:
        self.num_prompt_tokens = len(output.prompt_token_ids or [])
        self.num_cached_tokens = output.num_cached_tokens or 0
        self.num_output_tokens += len(output.outputs[0].token_ids or [])
        self.parser.process(output.outputs[0])
        output_token_ids = output.outputs[0].token_ids or []
        self._accumulated_token_ids.extend(output_token_ids)

        # Only store if enable_response_messages is True, to save memory.
        if self.request.enable_response_messages:
            output_prompt = output.prompt or ""
            output_prompt_token_ids = output.prompt_token_ids or []
            if len(self.input_messages) == 0:
                self.input_messages.append(
                    ResponseRawMessageAndToken(
                        message=output_prompt,
                        tokens=output_prompt_token_ids,
                    )
                )
            else:
                self.output_messages.append(
                    ResponseRawMessageAndToken(
                        message=output_prompt,
                        tokens=output_prompt_token_ids,
                    )
                )
            self.output_messages.append(
                ResponseRawMessageAndToken(
                    message=output.outputs[0].text,
                    tokens=output.outputs[0].token_ids,
                )
            )

    def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
        self.parser.response_messages.extend(output)

    def need_builtin_tool_call(self) -> bool:
        """Return True if the last message is a builtin tool call
        that the request has enabled."""
        last_message = self.parser.response_messages[-1]
        if last_message.type != "function_call":
            return False
        if last_message.name in ("code_interpreter", "python"):
            return "python" in self.available_tools
        if last_message.name == "web_search_preview":
            return "browser" in self.available_tools
        if last_message.name.startswith("container"):
            return "container" in self.available_tools
        return False

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
    ) -> list[ResponseInputOutputItem]:
        self.called_tools.add("python")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result_parsable_context(self)
        args = json.loads(last_msg.arguments)
        param = {
            "code": args["code"],
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text

        message = ResponseFunctionToolCallOutputItem(
            id=f"mcpo_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=result_str,
            status="completed",
        )

        return [message]

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
    ) -> list[ResponseInputOutputItem]:
        self.called_tools.add("browser")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result_parsable_context(self)
        if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
            try:
                args = json.loads(last_msg.arguments)
            except json.JSONDecodeError as e:
                return _create_json_parse_error_messages(last_msg, e)
        else:
            args = json.loads(last_msg.arguments)
        result = await tool_session.call_tool("search", args)
        result_str = result.content[0].text

        message = ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=result_str,
            status="completed",
        )

        return [message]

    async def call_container_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """
        Call the container tool. Expect this to be run in a stateful docker
        container with a command line terminal.
        The official container tool expects at least the following format:
        - tool name: "exec"
        - args:
            {
                "cmd": List[str],               # command to execute
                "workdir": Optional[str],       # current working directory
                "env": Optional[dict],          # environment variables
                "session_name": Optional[str],  # session name
                "timeout": Optional[int],       # timeout in seconds
                "user": Optional[str],          # user name
            }
        """
        self.called_tools.add("container")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result_parsable_context(self)
        # tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
        if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
            try:
                args = json.loads(last_msg.arguments)
            except json.JSONDecodeError as e:
                return _create_json_parse_error_messages(last_msg, e)
        else:
            args = json.loads(last_msg.arguments)
        result = await tool_session.call_tool("exec", args)
        result_str = result.content[0].text

        message = ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=result_str,
            status="completed",
        )

        return [message]

    async def call_tool(self) -> list[ResponseInputOutputItem]:
        if not self.parser.response_messages:
            return []
        last_msg = self.parser.response_messages[-1]
        # Rewrite the id to mark this as an MCP function call.
        last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
        self.parser.response_messages[-1] = last_msg
        if last_msg.name == "code_interpreter":
            return await self.call_python_tool(self._tool_sessions["python"], last_msg)
        elif last_msg.name == "web_search_preview":
            return await self.call_search_tool(self._tool_sessions["browser"], last_msg)
        elif last_msg.name.startswith("container"):
            return await self.call_container_tool(
                self._tool_sessions["container"], last_msg
            )
        return []

    def render_for_completion(self):
        raise NotImplementedError("Should not be called.")

    async def init_tool_sessions(
        self,
        tool_server: ToolServer | None,
        exit_stack: AsyncExitStack,
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ):
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name in self._tool_sessions:
                    continue

                tool_type = _map_tool_name_to_tool_type(tool_name)
                headers = (
                    mcp_tools[tool_type].headers if tool_type in mcp_tools else None
                )
                tool_session = await exit_stack.enter_async_context(
                    tool_server.new_session(tool_name, request_id, headers)
                )
                self._tool_sessions[tool_name] = tool_session
            exit_stack.push_async_exit(self.cleanup_session)

    async def cleanup_session(self, *args, **kwargs) -> None:
        """Can be used as a coroutine in __aexit__."""

        async def cleanup_tool_session(tool_session):
            if not isinstance(tool_session, Tool):
                logger.info(
                    "Cleaning up tool session for %s", tool_session._client_info
                )
                with contextlib.suppress(Exception):
                    await tool_session.call_tool("cleanup_session", {})

        await asyncio.gather(
            *(
                cleanup_tool_session(self._tool_sessions[tool])
                for tool in self.called_tools
            )
        )


class HarmonyContext(ConversationContext):
    def __init__(
        self,
        messages: list,
        available_tools: list[str],
    ):
        self._messages = messages
        self.finish_reason: str | None = None
        self.available_tools = available_tools
        self._tool_sessions: dict[str, Union["ClientSession", Tool]] = {}
        self.called_tools: set[str] = set()

        self.parser = get_streamable_parser_for_assistant()
        self.num_init_messages = len(messages)
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0
        self.num_cached_tokens = 0
        self.num_reasoning_tokens = 0
        self.num_tool_output_tokens = 0

        # Turn tracking: replaces multiple individual tracking variables.
        self.current_turn_metrics = TurnMetrics()
        # Track metrics for all turns.
        self.all_turn_metrics: list[TurnMetrics] = []
        self.is_first_turn = True
        self.first_tok_of_message = True  # For streaming support

    def _update_num_reasoning_tokens(self):
        channel = self.parser.current_channel
        if channel == "analysis":
            self.num_reasoning_tokens += 1
        elif channel == "commentary" and self.parser.current_recipient is not None:
            # Tool interactions (python/browser/container) are hidden.
            # Preambles (recipient=None) are visible user text.
            self.num_reasoning_tokens += 1

    def append_output(self, output: RequestOutput) -> None:
        output_token_ids = output.outputs[0].token_ids
        self.parser = get_streamable_parser_for_assistant()
        for token_id in output_token_ids:
            self.parser.process(token_id)
            # Check if the current token is part of reasoning content
            self._update_num_reasoning_tokens()
        self._update_prefill_token_usage(output)
        self._update_decode_token_usage(output)
        # Append current turn to the all-turn list for next turn's calculations
        self.all_turn_metrics.append(self.current_turn_metrics.copy())
        self.current_turn_metrics.reset()
        # In the non-streaming case, append_output is called only once before
        # tool calling, so we can append all the parser messages to _messages.
        output_msgs = self.parser.messages
        # The responses finish reason is set in the last message
        self.finish_reason = output.outputs[0].finish_reason
        self._messages.extend(output_msgs)

    def append_tool_output(self, output: list[Message]) -> None:
        output_msgs = output
        self._messages.extend(output_msgs)

    def _update_prefill_token_usage(self, output: RequestOutput) -> None:
        """Update token usage statistics for the prefill phase of generation.

        The prefill phase processes the input prompt tokens. This method:
        1. Counts the prompt tokens for this turn
        2. Calculates tool output tokens for multi-turn conversations
        3. Updates cached token counts
        4. Tracks state for next turn calculations

        Tool output tokens are calculated as:
            current_prompt_tokens - last_turn_prompt_tokens -
            last_turn_output_tokens
        This represents tokens added between turns (typically tool responses).

        Args:
            output: The RequestOutput containing prompt token information
        """
        if output.prompt_token_ids is not None:
            this_turn_input_tokens = len(output.prompt_token_ids)
        else:
            this_turn_input_tokens = 0
            logger.error("RequestOutput appended contains no prompt_token_ids.")

        # Update current turn input tokens
        self.current_turn_metrics.input_tokens = this_turn_input_tokens
        self.num_prompt_tokens += this_turn_input_tokens

        # Calculate tool tokens (except on the first turn)
        if self.is_first_turn:
            self.is_first_turn = False
        else:
            previous_turn = self.all_turn_metrics[-1]
            # Start counting tool tokens after the first turn:
            # tool tokens = this turn's prefill - last turn's prefill -
            # last turn's decode
            this_turn_tool_tokens = (
                self.current_turn_metrics.input_tokens
                - previous_turn.input_tokens
                - previous_turn.output_tokens
            )

            # Handle negative tool token counts (shouldn't happen in normal
            # cases)
            if this_turn_tool_tokens < 0:
                logger.error(
                    "Negative tool output tokens calculated: %d "
                    "(current_input=%d, previous_input=%d, "
                    "previous_output=%d). Setting to 0.",
                    this_turn_tool_tokens,
                    self.current_turn_metrics.input_tokens,
                    previous_turn.input_tokens,
                    previous_turn.output_tokens,
                )
                this_turn_tool_tokens = 0

            self.num_tool_output_tokens += this_turn_tool_tokens
            self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens

        # Update cached tokens
        num_cached_token = output.num_cached_tokens
        if num_cached_token is not None:
            self.num_cached_tokens += num_cached_token
            self.current_turn_metrics.cached_input_tokens = num_cached_token

    def _update_decode_token_usage(self, output: RequestOutput) -> int:
        """Update token usage statistics for the decode phase of generation.

        The decode phase processes the generated output tokens. This method:
        1. Counts output tokens from all completion outputs
        2. Updates the total output token count
        3. Tracks tokens generated in the current turn

        In streaming mode, this is called for each token generated.
        In non-streaming mode, this is called once with all output tokens.

        Args:
            output: The RequestOutput containing generated token information

        Returns:
            int: Number of output tokens processed in this call
        """
        updated_output_token_count = 0
        if output.outputs:
            for completion_output in output.outputs:
                # only keep last round
                updated_output_token_count += len(completion_output.token_ids)
            self.num_output_tokens += updated_output_token_count
            self.current_turn_metrics.output_tokens += updated_output_token_count
        return updated_output_token_count

    @property
    def messages(self) -> list:
        return self._messages

    def need_builtin_tool_call(self) -> bool:
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is None:
            return False
        if recipient.startswith("browser."):
            return "browser" in self.available_tools
        if recipient.startswith("python"):
            return "python" in self.available_tools
        if recipient.startswith("container."):
            return "container" in self.available_tools
        return False

    async def call_tool(self) -> list[Message]:
        if not self.messages:
            return []
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is not None:
            if recipient.startswith("browser."):
                return await self.call_search_tool(
                    self._tool_sessions["browser"], last_msg
                )
            elif recipient.startswith("python"):
                return await self.call_python_tool(
                    self._tool_sessions["python"], last_msg
                )
            elif recipient.startswith("container."):
                return await self.call_container_tool(
                    self._tool_sessions["container"], last_msg
                )
        raise ValueError("No tool call found")

    def render_for_completion(self) -> list[int]:
        return render_for_completion(self.messages)

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        self.called_tools.add("browser")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1]
        if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
            try:
                args = json.loads(last_msg.content[0].text)
            except json.JSONDecodeError as e:
                return _create_json_parse_error_messages(last_msg, e)
        else:
            args = json.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [
            Message(
                author=author,
                content=[content],
                recipient=Role.ASSISTANT,
                channel=last_msg.channel,
            )
        ]

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        self.called_tools.add("python")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        param = {
            "code": last_msg.content[0].text,
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text

        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name="python")

        return [
            Message(
                author=author,
                content=[content],
                channel=last_msg.channel,
                recipient=Role.ASSISTANT,
            )
        ]

    async def init_tool_sessions(
        self,
        tool_server: ToolServer | None,
        exit_stack: AsyncExitStack,
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ):
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name not in self._tool_sessions:
                    tool_type = _map_tool_name_to_tool_type(tool_name)
                    headers = (
                        mcp_tools[tool_type].headers if tool_type in mcp_tools else None
                    )
                    tool_session = await exit_stack.enter_async_context(
                        tool_server.new_session(tool_name, request_id, headers)
                    )
                    self._tool_sessions[tool_name] = tool_session
            exit_stack.push_async_exit(self.cleanup_session)

    async def call_container_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """
        Call the container tool. Expect this to be run in a stateful docker
        container with a command line terminal.
        The official container tool expects at least the following format:
        - tool name: "exec"
        - args:
            {
                "cmd": List[str],               # command to execute
                "workdir": Optional[str],       # current working directory
                "env": Optional[dict],          # environment variables
                "session_name": Optional[str],  # session name
                "timeout": Optional[int],       # timeout in seconds
                "user": Optional[str],          # user name
            }
        """
        self.called_tools.add("container")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
        if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
            try:
                args = json.loads(last_msg.content[0].text)
            except json.JSONDecodeError as e:
                return _create_json_parse_error_messages(last_msg, e)
        else:
            args = json.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [
            Message(
                author=author,
                content=[content],
                recipient=Role.ASSISTANT,
                channel=last_msg.channel,
            )
        ]

    async def cleanup_session(self, *args, **kwargs) -> None:
        """Can be used as a coroutine in __aexit__."""

        async def cleanup_tool_session(tool_session):
            if not isinstance(tool_session, Tool):
                logger.info(
                    "Cleaning up tool session for %s", tool_session._client_info
                )
                with contextlib.suppress(Exception):
                    await tool_session.call_tool("cleanup_session", {})

        await asyncio.gather(
            *(
                cleanup_tool_session(self._tool_sessions[tool])
                for tool in self.called_tools
            )
        )


class StreamingHarmonyContext(HarmonyContext):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_output = None

        self.parser = get_streamable_parser_for_assistant()
        self.encoding = get_encoding()
        self.last_tok = None
        self.first_tok_of_message = True
        self.last_content_delta = None

    @property
    def messages(self) -> list:
        return self._messages

    def append_output(self, output: RequestOutput) -> None:
        # append_output is called for each output token in the streaming case,
        # so we only want to add the prompt tokens once per message.
        self.last_content_delta = None
        if self.first_tok_of_message:
            self._update_prefill_token_usage(output)
        # Reset self.first_tok_of_message if needed:
        # if the current token is the last one of the current message
        # (finished=True), then the next token processed will mark the
        # beginning of a new message
        self.first_tok_of_message = output.finished
        last_delta_text = ""
        for tok in output.outputs[0].token_ids:
            self.parser.process(tok)
            last_delta_text += self.parser.last_content_delta or ""
        if last_delta_text:
            self.last_content_delta = last_delta_text
        self._update_decode_token_usage(output)

        # For streaming, update the previous turn when the message is complete
        if output.finished:
            self.all_turn_metrics.append(self.current_turn_metrics.copy())
            self.current_turn_metrics.reset()
        # Check if the current token is part of reasoning content
        self._update_num_reasoning_tokens()
        self.last_tok = tok
        if len(self._messages) - self.num_init_messages < len(self.parser.messages):
            self._messages.extend(
                self.parser.messages[len(self._messages) - self.num_init_messages :]
            )

    def append_tool_output(self, output: list[Message]) -> None:
        # Handle the case of tool output in direct message format
        assert len(output) == 1, "Tool output should be a single message"
        msg = output[0]
        # Sometimes the recipient is not set for tool messages,
        # so we set it to "assistant"
        if msg.author.role == Role.TOOL and msg.recipient is None:
            msg.recipient = "assistant"
        toks = self.encoding.render(msg)
        for tok in toks:
            self.parser.process(tok)
        self.last_tok = toks[-1]
        # TODO: add tool_output messages to self._messages

    def is_expecting_start(self) -> bool:
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

    def render_for_completion(self) -> list[int]:
        # This token list now ends with next turn's starting tokens
        # (`<|start|>assistant`); we need to run them through the parser.
        rendered_tokens = super().render_for_completion()

        last_n = -1
        to_process = []
        while rendered_tokens[last_n] != self.last_tok:
            to_process.append(rendered_tokens[last_n])
            last_n -= 1
        for tok in reversed(to_process):
            self.parser.process(tok)

        return rendered_tokens
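
A standalone sketch (not part of this diff) of the per-turn tool-token accounting that _update_prefill_token_usage implements, using made-up integer counts in place of RequestOutput fields.

# Last turn: 120 prompt (prefill) tokens, 40 generated (decode) tokens.
prev_input_tokens, prev_output_tokens = 120, 40
# This turn's prompt contains last turn's prompt and output again, plus
# whatever was injected between turns (typically tool responses).
cur_input_tokens = 200

tool_output_tokens = cur_input_tokens - prev_input_tokens - prev_output_tokens
tool_output_tokens = max(tool_output_tokens, 0)  # clamped to 0, as in the code
print(tool_output_tokens)  # 200 - 120 - 40 = 40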

vllm/entrypoints/openai/responses/harmony.py (new file, 552 lines)
@@ -0,0 +1,552 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Harmony ↔ Responses API conversion utilities.
|
||||
|
||||
Handles two directions:
|
||||
1. Response Input → Harmony Messages (input parsing)
|
||||
2. Harmony Messages → Response Output Items (output parsing)
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
ResponseOutputItem,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ActionFind,
|
||||
ActionOpenPage,
|
||||
ActionSearch,
|
||||
ResponseFunctionWebSearch,
|
||||
)
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from openai_harmony import Author, Message, Role, StreamableParser, TextContent
|
||||
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
BUILTIN_TOOL_TO_MCP_SERVER_LABEL,
|
||||
flatten_chat_text_content,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Private helpers for input parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_harmony_format_message(chat_msg: dict) -> Message:
|
||||
"""Reconstruct a Message from Harmony-format dict,
|
||||
preserving channel, recipient, and content_type."""
|
||||
author_dict = chat_msg["author"]
|
||||
role = author_dict.get("role")
|
||||
name = author_dict.get("name")
|
||||
|
||||
raw_content = chat_msg.get("content", "")
|
||||
if isinstance(raw_content, list):
|
||||
# TODO: Support refusal and non-text content types.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in raw_content]
|
||||
elif isinstance(raw_content, str):
|
||||
contents = [TextContent(text=raw_content)]
|
||||
else:
|
||||
contents = [TextContent(text="")]
|
||||
|
||||
if name:
|
||||
msg = Message.from_author_and_contents(Author.new(Role(role), name), contents)
|
||||
else:
|
||||
msg = Message.from_role_and_contents(Role(role), contents)
|
||||
|
||||
channel = chat_msg.get("channel")
|
||||
if channel:
|
||||
msg = msg.with_channel(channel)
|
||||
recipient = chat_msg.get("recipient")
|
||||
if recipient:
|
||||
msg = msg.with_recipient(recipient)
|
||||
content_type = chat_msg.get("content_type")
|
||||
if content_type:
|
||||
msg = msg.with_content_type(content_type)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
|
||||
"""Parse an OpenAI chat-format dict into Harmony messages."""
|
||||
role = chat_msg.get("role")
|
||||
if role is None:
|
||||
raise ValueError(f"Message has no 'role' key: {chat_msg}")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
if role == "assistant" and tool_calls:
|
||||
msgs: list[Message] = []
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
if name and not name.startswith("functions."):
|
||||
name = f"functions.{name}"
|
||||
content = chat_msg.get("content", "") or ""
|
||||
content = flatten_chat_text_content(content)
|
||||
# NOTE: .with_recipient("assistant") is required on tool messages
|
||||
# to match parse_chat_input_to_harmony_message behavior and ensure
|
||||
# proper routing in the Harmony protocol.
|
||||
msg = (
|
||||
Message.from_author_and_content(Author.new(Role.TOOL, name), content)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
return [msg]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Public input parsing functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def response_input_to_harmony(
|
||||
response_msg: ResponseInputOutputItem,
|
||||
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
|
||||
) -> Message:
|
||||
"""Convert a single ResponseInputOutputItem into a Harmony Message."""
|
||||
if not isinstance(response_msg, dict):
|
||||
response_msg = response_msg.model_dump()
|
||||
if "type" not in response_msg or response_msg["type"] == "message":
|
||||
role = response_msg["role"]
|
||||
content = response_msg["content"]
|
||||
# Add prefix for developer messages.
|
||||
# <|start|>developer<|message|># Instructions {instructions}<|end|>
|
||||
text_prefix = "Instructions:\n" if role == "developer" else ""
|
||||
if isinstance(content, str):
|
||||
msg = Message.from_role_and_content(role, text_prefix + content)
|
||||
else:
|
||||
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: ResponseFunctionToolCall | None = None
|
||||
for prev_response in reversed(prev_responses):
|
||||
if (
|
||||
isinstance(prev_response, ResponseFunctionToolCall)
|
||||
and prev_response.call_id == call_id
|
||||
):
|
||||
call_response = prev_response
|
||||
break
|
||||
if call_response is None:
|
||||
raise ValueError(f"No call message found for {call_id}")
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{call_response.name}"),
|
||||
response_msg["output"],
|
||||
)
|
||||
elif response_msg["type"] == "reasoning":
|
||||
content = response_msg["content"]
|
||||
assert len(content) == 1
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
|
||||
elif response_msg["type"] == "function_call":
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{response_msg['name']}")
|
||||
msg = msg.with_content_type("json")
|
||||
else:
|
||||
raise ValueError(f"Unknown input type: {response_msg['type']}")
|
||||
return msg
|
||||
|
||||
|
||||
def response_previous_input_to_harmony(chat_msg) -> list[Message]:
|
||||
"""Parse a message from request.previous_input_messages
|
||||
into Harmony messages.
|
||||
|
||||
Supports both OpenAI chat format ({"role": "..."}) and
|
||||
Harmony format ({"author": {"role": "..."}}).
|
||||
"""
|
||||
if not isinstance(chat_msg, dict):
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
if "author" in chat_msg and isinstance(chat_msg.get("author"), dict):
|
||||
return [_parse_harmony_format_message(chat_msg)]
|
||||
|
||||
return _parse_chat_format_message(chat_msg)
|
||||
|
||||
|
||||
def construct_harmony_previous_input_messages(
|
||||
request: ResponsesRequest,
|
||||
) -> list[Message]:
|
||||
"""Build a Harmony message list from request.previous_input_messages.
|
||||
|
||||
Filters out system/developer messages to match OpenAI behavior where
|
||||
instructions are always taken from the most recent Responses API request.
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
if request.previous_input_messages:
|
||||
for message in request.previous_input_messages:
|
||||
# Handle both Message objects and dictionary inputs
|
||||
if isinstance(message, Message):
|
||||
message_role = message.author.role
|
||||
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
|
||||
continue
|
||||
messages.append(message)
|
||||
else:
|
||||
harmony_messages = response_previous_input_to_harmony(message)
|
||||
for harmony_msg in harmony_messages:
|
||||
message_role = harmony_msg.author.role
|
||||
if message_role == Role.SYSTEM or message_role == Role.DEVELOPER:
|
||||
continue
|
||||
messages.append(harmony_msg)
|
||||
return messages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Private helpers for output parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
|
||||
"""Parse browser tool calls (search, open, find) into web search items."""
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
|
||||
# Parse JSON args (with retry detection)
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Invalid JSON in browser tool call, using error placeholder: %s",
|
||||
content.text,
|
||||
)
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
browser_call = {
|
||||
"query": json_retry_output_message,
|
||||
"url": json_retry_output_message,
|
||||
"pattern": json_retry_output_message,
|
||||
}
|
||||
|
||||
# Create appropriate action based on recipient
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search"
|
||||
)
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
|
||||
)
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(
|
||||
pattern=browser_call.get("pattern", ""),
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
|
||||
return ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
|
||||
|
||||
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
|
||||
"""Parse function calls into function tool call items."""
|
||||
function_name = recipient.split(".")[-1]
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_reasoning(message: Message) -> list[ResponseOutputItem]:
|
||||
"""Parse reasoning/analysis content into reasoning items."""
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_final_message(message: Message) -> ResponseOutputItem:
|
||||
"""Parse final channel messages into output message items."""
|
||||
contents = []
|
||||
for content in message.content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content.text,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
contents.append(output_text)
|
||||
return ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=contents,
|
||||
role=message.author.role,
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
|
||||
|
||||
def _parse_mcp_recipient(recipient: str) -> tuple[str, str]:
|
||||
"""Parse MCP recipient into (server_label, tool_name).
|
||||
|
||||
For dotted recipients like "repo_browser.list":
|
||||
- server_label: "repo_browser" (namespace/server)
|
||||
- tool_name: "list" (specific tool)
|
||||
|
||||
For simple recipients like "filesystem":
|
||||
- server_label: "filesystem"
|
||||
- tool_name: "filesystem"
|
||||
"""
|
||||
if "." in recipient:
|
||||
server_label = recipient.split(".")[0]
|
||||
tool_name = recipient.split(".")[-1]
|
||||
else:
|
||||
server_label = recipient
|
||||
tool_name = recipient
|
||||
return server_label, tool_name
|
||||
|
||||
|
||||
def _parse_mcp_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
|
||||
"""Parse MCP calls into MCP call items."""
|
||||
# Handle built-in tools that need server_label mapping
|
||||
if recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
|
||||
server_label = BUILTIN_TOOL_TO_MCP_SERVER_LABEL[recipient]
|
||||
tool_name = recipient
|
||||
else:
|
||||
server_label, tool_name = _parse_mcp_recipient(recipient)
|
||||
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
response_item = McpCall(
|
||||
arguments=content.text,
|
||||
type="mcp_call",
|
||||
name=tool_name,
|
||||
server_label=server_label,
|
||||
id=f"mcp_{random_uuid()}",
|
||||
status="completed",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_message_no_recipient(
|
||||
message: Message,
|
||||
) -> list[ResponseOutputItem]:
|
||||
"""Parse a Harmony message with no recipient based on its channel."""
|
||||
if message.channel == "analysis":
|
||||
return _parse_reasoning(message)
|
||||
|
||||
if message.channel in ("commentary", "final"):
|
||||
# Per Harmony format, preambles (commentary with no recipient) and
|
||||
# final channel content are both intended to be shown to end-users.
|
||||
# See: https://cookbook.openai.com/articles/openai-harmony
|
||||
return [_parse_final_message(message)]
|
||||
|
||||
raise ValueError(f"Unknown channel: {message.channel}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Public output parsing functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def harmony_to_response_output(message: Message) -> list[ResponseOutputItem]:
|
||||
"""Parse a Harmony message into a list of output response items.
|
||||
|
||||
This is the main dispatcher that routes based on channel and recipient.
|
||||
"""
|
||||
if message.author.role != "assistant":
|
||||
# This is a message from a tool to the assistant (e.g., search result).
|
||||
# Don't include it in the final output for now. This aligns with
|
||||
# OpenAI's behavior on models like o4-mini.
|
||||
return []
|
||||
|
||||
output_items: list[ResponseOutputItem] = []
|
||||
recipient = message.recipient
|
||||
|
||||
if recipient is not None:
|
||||
# Browser tool calls (browser.search, browser.open, browser.find)
|
||||
if recipient.startswith("browser."):
|
||||
output_items.append(_parse_browser_tool_call(message, recipient))
|
||||
|
||||
# Function calls (should only happen on commentary channel)
|
||||
elif message.channel == "commentary" and recipient.startswith("functions."):
|
||||
output_items.extend(_parse_function_call(message, recipient))
|
||||
|
||||
# Built-in MCP tools (python, browser, container)
|
||||
elif recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
|
||||
output_items.extend(_parse_reasoning(message))
|
||||
|
||||
# All other recipients are MCP calls
|
||||
else:
|
||||
output_items.extend(_parse_mcp_call(message, recipient))
|
||||
|
||||
# No recipient - handle based on channel for non-tool messages
|
||||
else:
|
||||
output_items.extend(_parse_message_no_recipient(message))
|
||||
|
||||
return output_items
|
||||
|
||||
|
||||
def parser_state_to_response_output(
    parser: StreamableParser,
) -> list[ResponseOutputItem]:
    """Extract in-progress response items from incomplete parser state.

    Called when the parser has buffered content that hasn't formed a
    complete message yet (e.g., generation was cut short).
    """
    if not parser.current_content:
        return []
    if parser.current_role != Role.ASSISTANT:
        return []
    current_recipient = parser.current_recipient
    if current_recipient is not None and current_recipient.startswith("browser."):
        return []

    if current_recipient and parser.current_channel in ("commentary", "analysis"):
        if current_recipient.startswith("functions."):
            rid = random_uuid()
            return [
                ResponseFunctionToolCall(
                    arguments=parser.current_content,
                    call_id=f"call_{rid}",
                    type="function_call",
                    name=current_recipient.split(".")[-1],
                    id=f"fc_{rid}",
                    status="in_progress",
                )
            ]
        # Built-in MCP tools (python, browser, container)
        elif current_recipient in BUILTIN_TOOL_TO_MCP_SERVER_LABEL:
            return [
                ResponseReasoningItem(
                    id=f"rs_{random_uuid()}",
                    summary=[],
                    type="reasoning",
                    content=[
                        ResponseReasoningTextContent(
                            text=parser.current_content, type="reasoning_text"
                        )
                    ],
                    status=None,
                )
            ]
        # All other recipients are MCP calls
        else:
            rid = random_uuid()
            server_label, tool_name = _parse_mcp_recipient(current_recipient)
            return [
                McpCall(
                    arguments=parser.current_content,
                    type="mcp_call",
                    name=tool_name,
                    server_label=server_label,
                    id=f"mcp_{rid}",
                    status="in_progress",
                )
            ]

    if parser.current_channel == "commentary":
        # Per Harmony format, preambles (commentary with no recipient) are
        # intended to be shown to end-users, unlike analysis channel content.
        output_text = ResponseOutputText(
            text=parser.current_content,
            annotations=[],
            type="output_text",
            logprobs=None,
        )
        return [
            ResponseOutputMessage(
                id=f"msg_{random_uuid()}",
                content=[output_text],
                role="assistant",
                status="incomplete",
                type="message",
            )
        ]

    if parser.current_channel == "analysis":
        return [
            ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                summary=[],
                type="reasoning",
                content=[
                    ResponseReasoningTextContent(
                        text=parser.current_content, type="reasoning_text"
                    )
                ],
                status=None,
            )
        ]

    if parser.current_channel == "final":
        output_text = ResponseOutputText(
            text=parser.current_content,
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=[output_text],
            role="assistant",
            # if the parser still has messages (i.e., the generator got cut
            # off abruptly), this should be incomplete
            status="incomplete",
            type="message",
        )
        return [text_item]

    return []
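For orientation, a condensed summary of the mapping the branches above implement (informal comment table, not executable):

# channel               recipient         -> in-progress item
# analysis              None              -> ResponseReasoningItem
# commentary / final    None              -> ResponseOutputMessage ("incomplete")
# commentary/analysis   functions.<f>     -> ResponseFunctionToolCall ("in_progress")
# commentary/analysis   built-in tool     -> ResponseReasoningItem
# commentary/analysis   other recipients  -> McpCall ("in_progress")
# any                   browser.*         -> dropped (returns [])
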
641
vllm/entrypoints/openai/responses/protocol.py
Normal file
@@ -0,0 +1,641 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Literal, TypeAlias

import torch
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseMcpCallArgumentsDeltaEvent,
    ResponseMcpCallArgumentsDoneEvent,
    ResponseMcpCallCompletedEvent,
    ResponseMcpCallInProgressEvent,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses import (
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
from openai.types.responses import (
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
from openai.types.responses.tool import Tool
from openai_harmony import Message as OpenAIHarmonyMessage

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig

from openai.types.responses.response import IncompleteDetails, ToolChoice
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai.types.shared import Metadata, Reasoning
from pydantic import (
    Field,
    ValidationError,
    field_serializer,
    model_validator,
)

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
from vllm.utils import random_uuid

logger = init_logger(__name__)

_LONG_INFO = torch.iinfo(torch.long)


class InputTokensDetails(OpenAIBaseModel):
    cached_tokens: int
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)


class OutputTokensDetails(OpenAIBaseModel):
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)


class ResponseUsage(OpenAIBaseModel):
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int

def serialize_message(msg):
    """
    Serializes a single message
    """
    if isinstance(msg, dict):
        return msg
    elif hasattr(msg, "to_dict"):
        return msg.to_dict()
    else:
        # Fall back to the pydantic dump
        return msg.model_dump_json()


def serialize_messages(msgs):
    """
    Serializes multiple messages
    """
    return [serialize_message(msg) for msg in msgs] if msgs else None

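A small illustration of the fallback order above (dict passthrough, then to_dict(), then the pydantic dump; values illustrative):

assert serialize_message({"role": "user", "content": "hi"}) == {"role": "user", "content": "hi"}
assert serialize_messages(None) is None
assert serialize_messages([]) is None  # empty lists also collapse to None
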
class ResponseRawMessageAndToken(OpenAIBaseModel):
    """Class to show the raw message.
    If message / tokens diverge, tokens is the source of truth."""

    message: str
    tokens: list[int]
    type: Literal["raw_message_tokens"] = "raw_message_tokens"


ResponseInputOutputMessage: TypeAlias = (
    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem

class ResponsesRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    logit_bias: dict[str, float] | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    top_k: int | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None
    skip_special_tokens: bool = True
    include_stop_str_in_output: bool = False
    prompt_cache_key: str | None = Field(
        default=None,
        description=(
            "A key that was used to read from or write to the prompt cache. "
            "Note: This field has not been implemented yet "
            "and vLLM will ignore it."
        ),
    )

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker from guessing prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-background "
            "requests."
        ),
    )
    # Similar to input_messages / output_messages in ResponsesResponse,
    # we take in previous_input_messages (i.e., in Harmony format).
    # This cannot be used in conjunction with previous_response_id.
    # TODO: consider supporting non-Harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )

    repetition_penalty: float | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    ignore_eos: bool = False
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
        default=None,
        description=(
            "Additional request parameters with (list of) string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:responses-extra-params]

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        from .utils import should_continue_final_message

        # Check if we should continue the final message (partial completion).
        # This enables Anthropic-style partial message completion where the
        # user provides an incomplete assistant message to continue from.
        continue_final = should_continue_final_message(self.input)

        reasoning = self.reasoning

        return ChatParams(
            chat_template=default_template,
            chat_template_content_format=default_template_content_format,
            chat_template_kwargs=merge_kwargs(  # To remove unset values
                {},
                dict(
                    add_generation_prompt=not continue_final,
                    continue_final_message=continue_final,
                    reasoning_effort=None if reasoning is None else reasoning.effort,
                ),
            ),
        )

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=self.max_output_tokens or 0,
            truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_output_tokens",
        )

    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)

        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = self.structured_outputs

        # Also check text.format for OpenAI-style json_schema
        if self.text is not None and self.text.format is not None:
            if structured_outputs is not None:
                raise ValueError(
                    "Cannot specify both structured_outputs and text.format"
                )
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_  # type: ignore[call-arg]
                    # --follow-imports skip hides the class definition but also
                    # hides multiple third-party conflicts, so this is the
                    # lesser evil.
                )

        stop = self.stop if self.stop else []
        if isinstance(stop, str):
            stop = [stop]

        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            stop=stop,
            repetition_penalty=repetition_penalty,
            seed=self.seed,
            ignore_eos=self.ignore_eos,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
            logit_bias=self.logit_bias,
            extra_args=self.vllm_xargs or {},
            skip_clone=True,  # Created fresh per request, safe to skip clone
            skip_special_tokens=self.skip_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

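A hedged usage sketch of the conversion above (field values are illustrative; `default_max_tokens` would normally come from the engine's remaining context budget):

req = ResponsesRequest(input="Hello", temperature=0.2, max_output_tokens=64)
params = req.to_sampling_params(default_max_tokens=128)
assert params.max_tokens == 64  # min(max_output_tokens, default_max_tokens)
assert params.temperature == 0.2
assert params.output_kind == RequestOutputKind.FINAL_ONLY  # stream defaults to False
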
    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        if self.include is None:
            return False
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    def validate_background(cls, data):
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    def validate_prompt(cls, data):
        if data.get("prompt") is not None:
            raise VLLMValidationError(
                "prompt template is not supported", parameter="prompt"
            )
        return data

    @model_validator(mode="before")
    def check_cache_salt_support(cls, data):
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.
        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.
        """

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        data["input"] = processed_input
        return data

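Illustrative effect of the validator above: a raw function_call dict in `input` is promoted so the union type resolves cleanly (values hypothetical):

req = ResponsesRequest(
    input=[{"type": "function_call", "call_id": "call_abc",
            "name": "get_weather", "arguments": "{}"}]
)
assert isinstance(req.input[0], ResponseFunctionToolCall)
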
class ResponsesResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-response-extra-params]
    # These are populated when enable_response_messages is set to True.
    # NOTE: custom serialization is needed;
    # see serialize_input_messages and serialize_output_messages
    input_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages is set, the raw input messages to the "
            "model are included."
        ),
    )
    output_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages is set, the raw output messages of the "
            "model are included."
        ),
    )
    # --8<-- [end:responses-response-extra-params]

    # NOTE: openai_harmony doesn't serialize TextContent properly.
    # TODO: this fixes it for TextContent, but need to verify for tools etc.
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        return serialize_messages(msgs)

    # NOTE: openai_harmony doesn't serialize TextContent properly; this fixes it.
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: ResponseInputOutputMessage | None = None,
        output_messages: ResponseInputOutputMessage | None = None,
    ) -> "ResponsesResponse":
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )

# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    response: ResponsesResponse  # type: ignore[override]


StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
    | ResponseMcpCallArgumentsDeltaEvent
    | ResponseMcpCallArgumentsDoneEvent
    | ResponseMcpCallInProgressEvent
    | ResponseMcpCallCompletedEvent
)
1723
vllm/entrypoints/openai/responses/serving.py
Normal file
File diff suppressed because it is too large
798
vllm/entrypoints/openai/responses/streaming_events.py
Normal file
@@ -0,0 +1,798 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Streaming SSE event builders for the Responses API.

Pure functions that translate streaming state + delta data into
OpenAI Response API SSE events. Used by the streaming event
processors in serving.py.

The file is organized as:
1. StreamingState dataclass + utility helpers
2. Shared leaf helpers — delta events (take plain strings, no context)
3. Shared leaf helpers — done events (take plain strings, no context)
4. Harmony-specific dispatchers (route ctx/previous_item → leaf helpers)
5. Harmony-specific tool lifecycle helpers
"""

import json
from dataclasses import dataclass
from typing import Final

from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseCodeInterpreterToolCallParam,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionCallArgumentsDeltaEvent,
    ResponseFunctionCallArgumentsDoneEvent,
    ResponseFunctionToolCall,
    ResponseFunctionWebSearch,
    ResponseMcpCallArgumentsDeltaEvent,
    ResponseMcpCallArgumentsDoneEvent,
    ResponseMcpCallCompletedEvent,
    ResponseMcpCallInProgressEvent,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseTextDeltaEvent,
    ResponseTextDoneEvent,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
    response_function_web_search,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai_harmony import Message as HarmonyMessage

from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.responses.context import StreamingHarmonyContext
from vllm.entrypoints.openai.responses.protocol import (
    ResponseReasoningPartAddedEvent,
    ResponseReasoningPartDoneEvent,
    StreamingResponsesResponse,
)
from vllm.utils import random_uuid

TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
    "python": "code_interpreter",
    "container": "container",
    "browser": "web_search_preview",
}

def _resolve_mcp_name_label(recipient: str) -> tuple[str, str]:
    """Resolve MCP tool name and server label from a recipient string.

    - ``mcp.*`` recipients: strip the prefix, use the bare name as both
      name and server_label.
    - Everything else: use the recipient as the name and look up the
      server_label in TOOL_NAME_TO_MCP_SERVER_LABEL.
    """
    if recipient.startswith("mcp."):
        name = recipient[len("mcp.") :]
        return name, name
    return recipient, TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)

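Concrete resolutions under the mapping above (illustrative names):

assert _resolve_mcp_name_label("mcp.weather") == ("weather", "weather")
assert _resolve_mcp_name_label("python") == ("python", "code_interpreter")
assert _resolve_mcp_name_label("my_tool") == ("my_tool", "my_tool")  # unmapped fallback
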
@dataclass
class StreamingState:
    """Mutable state for streaming event processing."""

    current_content_index: int = -1
    current_output_index: int = 0
    current_item_id: str = ""
    current_call_id: str = ""
    sent_output_item_added: bool = False
    is_first_function_call_delta: bool = False

    def reset_for_new_item(self) -> None:
        """Reset state when expecting a new output item."""
        self.current_output_index += 1
        self.sent_output_item_added = False
        self.is_first_function_call_delta = False
        self.current_call_id = ""

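For example, after finishing one output item the state advances like this (sketch):

state = StreamingState()
state.sent_output_item_added = True
state.reset_for_new_item()
assert state.current_output_index == 1
assert state.sent_output_item_added is False
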
def is_mcp_tool_by_namespace(recipient: str | None) -> bool:
    """
    Determine if a tool call is an MCP tool based on recipient prefix.

    - Tools starting with "functions." are function calls
    - Everything else is an MCP tool
    """
    if recipient is None:
        return False

    # Function calls have "functions." prefix
    # Everything else is an MCP tool
    return not recipient.startswith("functions.")

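Illustrative classifications:

assert is_mcp_tool_by_namespace("functions.get_weather") is False
assert is_mcp_tool_by_namespace("python") is True
assert is_mcp_tool_by_namespace("mcp.weather") is True
assert is_mcp_tool_by_namespace(None) is False
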
# =====================================================================
# Shared leaf helpers — delta events
# =====================================================================


def emit_text_delta_events(
    delta: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for text content delta streaming."""
    events: list[StreamingResponsesResponse] = []
    if not state.sent_output_item_added:
        state.sent_output_item_added = True
        state.current_item_id = f"msg_{random_uuid()}"
        events.append(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item=ResponseOutputMessage(
                    id=state.current_item_id,
                    type="message",
                    role="assistant",
                    content=[],
                    status="in_progress",
                ),
            )
        )
        state.current_content_index += 1
        events.append(
            ResponseContentPartAddedEvent(
                type="response.content_part.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item_id=state.current_item_id,
                content_index=state.current_content_index,
                part=ResponseOutputText(
                    type="output_text",
                    text="",
                    annotations=[],
                    logprobs=[],
                ),
            )
        )
    events.append(
        ResponseTextDeltaEvent(
            type="response.output_text.delta",
            sequence_number=-1,
            content_index=state.current_content_index,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
            delta=delta,
            # TODO: use logprobs from ctx.last_request_output
            logprobs=[],
        )
    )
    return events

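A sketch of the emitted sequence (assuming a fresh StreamingState): the first delta carries the item and content-part scaffolding, later deltas carry only the text delta event.

state = StreamingState()
first = emit_text_delta_events("Hel", state)
assert [e.type for e in first] == [
    "response.output_item.added",
    "response.content_part.added",
    "response.output_text.delta",
]
later = emit_text_delta_events("lo", state)
assert [e.type for e in later] == ["response.output_text.delta"]
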
def emit_reasoning_delta_events(
    delta: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for reasoning text delta streaming."""
    events: list[StreamingResponsesResponse] = []
    if not state.sent_output_item_added:
        state.sent_output_item_added = True
        state.current_item_id = f"msg_{random_uuid()}"
        events.append(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item=ResponseReasoningItem(
                    type="reasoning",
                    id=state.current_item_id,
                    summary=[],
                    status="in_progress",
                ),
            )
        )
        state.current_content_index += 1
        events.append(
            ResponseReasoningPartAddedEvent(
                type="response.reasoning_part.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item_id=state.current_item_id,
                content_index=state.current_content_index,
                part=ResponseReasoningTextContent(
                    text="",
                    type="reasoning_text",
                ),
            )
        )
    events.append(
        ResponseReasoningTextDeltaEvent(
            type="response.reasoning_text.delta",
            item_id=state.current_item_id,
            output_index=state.current_output_index,
            content_index=state.current_content_index,
            delta=delta,
            sequence_number=-1,
        )
    )
    return events

def emit_function_call_delta_events(
    delta: str,
    function_name: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for function call argument deltas."""
    events: list[StreamingResponsesResponse] = []
    if state.is_first_function_call_delta is False:
        state.is_first_function_call_delta = True
        state.current_item_id = f"fc_{random_uuid()}"
        state.current_call_id = f"call_{random_uuid()}"
        tool_call_item = ResponseFunctionToolCall(
            name=function_name,
            type="function_call",
            id=state.current_item_id,
            call_id=state.current_call_id,
            arguments="",
            status="in_progress",
        )
        events.append(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item=tool_call_item,
            )
        )
    # Always emit the delta (including on first call)
    events.append(
        ResponseFunctionCallArgumentsDeltaEvent(
            item_id=state.current_item_id,
            delta=delta,
            output_index=state.current_output_index,
            sequence_number=-1,
            type="response.function_call_arguments.delta",
        )
    )
    return events

def emit_mcp_delta_events(
    delta: str,
    state: StreamingState,
    recipient: str,
) -> list[StreamingResponsesResponse]:
    """Emit events for MCP tool delta streaming."""
    name, server_label = _resolve_mcp_name_label(recipient)
    events: list[StreamingResponsesResponse] = []
    if not state.sent_output_item_added:
        state.sent_output_item_added = True
        state.current_item_id = f"mcp_{random_uuid()}"
        events.append(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item=McpCall(
                    type="mcp_call",
                    id=state.current_item_id,
                    name=name,
                    arguments="",
                    server_label=server_label,
                    status="in_progress",
                ),
            )
        )
        events.append(
            ResponseMcpCallInProgressEvent(
                type="response.mcp_call.in_progress",
                sequence_number=-1,
                output_index=state.current_output_index,
                item_id=state.current_item_id,
            )
        )
    events.append(
        ResponseMcpCallArgumentsDeltaEvent(
            type="response.mcp_call_arguments.delta",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
            delta=delta,
        )
    )
    return events

def emit_code_interpreter_delta_events(
    delta: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for code interpreter delta streaming."""
    events: list[StreamingResponsesResponse] = []
    if not state.sent_output_item_added:
        state.sent_output_item_added = True
        state.current_item_id = f"tool_{random_uuid()}"
        events.append(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                sequence_number=-1,
                output_index=state.current_output_index,
                item=ResponseCodeInterpreterToolCallParam(
                    type="code_interpreter_call",
                    id=state.current_item_id,
                    code=None,
                    container_id="auto",
                    outputs=None,
                    status="in_progress",
                ),
            )
        )
        events.append(
            ResponseCodeInterpreterCallInProgressEvent(
                type="response.code_interpreter_call.in_progress",
                sequence_number=-1,
                output_index=state.current_output_index,
                item_id=state.current_item_id,
            )
        )
    events.append(
        ResponseCodeInterpreterCallCodeDeltaEvent(
            type="response.code_interpreter_call_code.delta",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
            delta=delta,
        )
    )
    return events


# =====================================================================
# Shared leaf helpers — done events
# =====================================================================

def emit_text_output_done_events(
    text: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events when a final text output item completes."""
    text_content = ResponseOutputText(
        type="output_text",
        text=text,
        annotations=[],
    )
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseTextDoneEvent(
            type="response.output_text.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            content_index=state.current_content_index,
            text=text,
            logprobs=[],
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseContentPartDoneEvent(
            type="response.content_part.done",
            sequence_number=-1,
            item_id=state.current_item_id,
            output_index=state.current_output_index,
            content_index=state.current_content_index,
            part=text_content,
        )
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=ResponseOutputMessage(
                id=state.current_item_id,
                type="message",
                role="assistant",
                content=[text_content],
                status="completed",
            ),
        )
    )
    return events

def emit_reasoning_done_events(
    text: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events when a reasoning (analysis) item completes."""
    content = ResponseReasoningTextContent(
        text=text,
        type="reasoning_text",
    )
    reasoning_item = ResponseReasoningItem(
        type="reasoning",
        content=[content],
        status="completed",
        id=state.current_item_id,
        summary=[],
    )
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseReasoningTextDoneEvent(
            type="response.reasoning_text.done",
            item_id=state.current_item_id,
            sequence_number=-1,
            output_index=state.current_output_index,
            content_index=state.current_content_index,
            text=text,
        )
    )
    events.append(
        ResponseReasoningPartDoneEvent(
            type="response.reasoning_part.done",
            sequence_number=-1,
            item_id=state.current_item_id,
            output_index=state.current_output_index,
            content_index=state.current_content_index,
            part=content,
        )
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=reasoning_item,
        )
    )
    return events

def emit_function_call_done_events(
    function_name: str,
    arguments: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events when a function call completes."""
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseFunctionCallArgumentsDoneEvent(
            type="response.function_call_arguments.done",
            arguments=arguments,
            name=function_name,
            item_id=state.current_item_id,
            output_index=state.current_output_index,
            sequence_number=-1,
        )
    )
    function_call_item = ResponseFunctionToolCall(
        type="function_call",
        arguments=arguments,
        name=function_name,
        id=state.current_item_id,
        call_id=state.current_call_id,
        status="completed",
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=function_call_item,
        )
    )
    return events

def emit_mcp_completion_events(
    recipient: str,
    arguments: str,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events when an MCP tool call completes."""
    name, server_label = _resolve_mcp_name_label(recipient)
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseMcpCallArgumentsDoneEvent(
            type="response.mcp_call_arguments.done",
            arguments=arguments,
            name=name,
            item_id=state.current_item_id,
            output_index=state.current_output_index,
            sequence_number=-1,
        )
    )
    events.append(
        ResponseMcpCallCompletedEvent(
            type="response.mcp_call.completed",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=McpCall(
                type="mcp_call",
                arguments=arguments,
                name=name,
                id=state.current_item_id,
                server_label=server_label,
                status="completed",
            ),
        )
    )
    return events

# =====================================================================
# Harmony-specific dispatchers
# =====================================================================


def emit_content_delta_events(
    ctx: StreamingHarmonyContext,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for content delta streaming based on channel type.

    This is a Harmony-specific dispatcher that extracts values from the
    Harmony context and delegates to shared leaf helpers.
    """
    delta = ctx.last_content_delta
    if not delta:
        return []

    channel = ctx.parser.current_channel
    recipient = ctx.parser.current_recipient

    if channel in ("final", "commentary") and recipient is None:
        # Preambles (commentary with no recipient) and final messages
        # are both user-visible text.
        return emit_text_delta_events(delta, state)
    elif channel == "analysis" and recipient is None:
        return emit_reasoning_delta_events(delta, state)
    # Built-in tools are normally triggered on the analysis channel;
    # however, they occasionally still show up on commentary.
    elif channel in ("commentary", "analysis") and recipient is not None:
        if recipient.startswith("functions."):
            function_name = recipient[len("functions.") :]
            return emit_function_call_delta_events(delta, function_name, state)
        elif recipient == "python":
            return emit_code_interpreter_delta_events(delta, state)
        elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
            return emit_mcp_delta_events(delta, state, recipient)

    return []

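In summary, the dispatcher above routes deltas as follows (informal comment table):

# final / commentary, no recipient       -> emit_text_delta_events
# analysis, no recipient                 -> emit_reasoning_delta_events
# commentary/analysis, functions.<f>     -> emit_function_call_delta_events
# commentary/analysis, python            -> emit_code_interpreter_delta_events
# commentary/analysis, mcp.* / other MCP -> emit_mcp_delta_events
# anything else                          -> no events
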
def emit_previous_item_done_events(
    previous_item: HarmonyMessage,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit done events for the previous item when expecting a new start.

    This is a Harmony-specific dispatcher that extracts values from the
    Harmony parser's message object and delegates to shared leaf helpers.
    """
    text = previous_item.content[0].text
    if previous_item.recipient is not None:
        # Deal with tool call
        if previous_item.recipient.startswith("functions."):
            function_name = previous_item.recipient[len("functions.") :]
            return emit_function_call_done_events(function_name, text, state)
        elif previous_item.recipient == "python":
            return emit_code_interpreter_completion_events(previous_item, state)
        elif (
            is_mcp_tool_by_namespace(previous_item.recipient)
            and state.current_item_id is not None
            and state.current_item_id.startswith("mcp_")
        ):
            return emit_mcp_completion_events(previous_item.recipient, text, state)
    elif previous_item.channel == "analysis":
        return emit_reasoning_done_events(text, state)
    elif previous_item.channel in ("commentary", "final"):
        # Preambles (commentary with no recipient) and final messages
        # are both user-visible text.
        return emit_text_output_done_events(text, state)
    return []


# =====================================================================
# Harmony-specific tool lifecycle helpers
# =====================================================================

def emit_browser_tool_events(
    previous_item: HarmonyMessage,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events for browser tool calls (web search)."""
    function_name = previous_item.recipient[len("browser.") :]
    parsed_args = json.loads(previous_item.content[0].text)
    action = None

    if function_name == "search":
        action = response_function_web_search.ActionSearch(
            type="search",
            query=parsed_args["query"],
        )
    elif function_name == "open":
        action = response_function_web_search.ActionOpenPage(
            type="open_page",
            # TODO: translate to url
            url=f"cursor:{parsed_args.get('cursor', '')}",
        )
    elif function_name == "find":
        action = response_function_web_search.ActionFind(
            type="find",
            pattern=parsed_args["pattern"],
            # TODO: translate to url
            url=f"cursor:{parsed_args.get('cursor', '')}",
        )
    else:
        raise ValueError(f"Unknown function name: {function_name}")

    state.current_item_id = f"tool_{random_uuid()}"
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=response_function_web_search.ResponseFunctionWebSearch(
                # TODO: generate a unique id for web search call
                type="web_search_call",
                id=state.current_item_id,
                action=action,
                status="in_progress",
            ),
        )
    )
    events.append(
        ResponseWebSearchCallInProgressEvent(
            type="response.web_search_call.in_progress",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseWebSearchCallSearchingEvent(
            type="response.web_search_call.searching",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    # enqueue
    events.append(
        ResponseWebSearchCallCompletedEvent(
            type="response.web_search_call.completed",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=ResponseFunctionWebSearch(
                type="web_search_call",
                id=state.current_item_id,
                action=action,
                status="completed",
            ),
        )
    )
    return events

def emit_code_interpreter_completion_events(
    previous_item: HarmonyMessage,
    state: StreamingState,
) -> list[StreamingResponsesResponse]:
    """Emit events when code interpreter completes."""
    events: list[StreamingResponsesResponse] = []
    events.append(
        ResponseCodeInterpreterCallCodeDoneEvent(
            type="response.code_interpreter_call_code.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
            code=previous_item.content[0].text,
        )
    )
    events.append(
        ResponseCodeInterpreterCallInterpretingEvent(
            type="response.code_interpreter_call.interpreting",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseCodeInterpreterCallCompletedEvent(
            type="response.code_interpreter_call.completed",
            sequence_number=-1,
            output_index=state.current_output_index,
            item_id=state.current_item_id,
        )
    )
    events.append(
        ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=state.current_output_index,
            item=ResponseCodeInterpreterToolCallParam(
                type="code_interpreter_call",
                id=state.current_item_id,
                code=previous_item.content[0].text,
                container_id="auto",
                outputs=[],
                status="completed",
            ),
        )
    )
    return events

def emit_tool_action_events(
    ctx: StreamingHarmonyContext,
    state: StreamingState,
    tool_server: ToolServer | None,
) -> list[StreamingResponsesResponse]:
    """Emit events for a tool action turn."""
    if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0:
        return []

    events: list[StreamingResponsesResponse] = []
    previous_item = ctx.parser.messages[-1]

    # Handle browser tool
    if (
        tool_server is not None
        and tool_server.has_tool("browser")
        and previous_item.recipient is not None
        and previous_item.recipient.startswith("browser.")
    ):
        events.extend(emit_browser_tool_events(previous_item, state))

    # Handle tool completion
    if (
        tool_server is not None
        and previous_item.recipient is not None
        and state.current_item_id is not None
        and state.sent_output_item_added
    ):
        recipient = previous_item.recipient
        if recipient == "python":
            events.extend(
                emit_code_interpreter_completion_events(previous_item, state)
            )
        elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
            events.extend(
                emit_mcp_completion_events(
                    recipient, previous_item.content[0].text, state
                )
            )

    return events
263
vllm/entrypoints/openai/responses/utils.py
Normal file
@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
    Function as FunctionCallTool,
)
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool

from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionMessageParam
from vllm.entrypoints.openai.responses.protocol import ResponseInputOutputItem

def should_continue_final_message(
    request_input: str | list[ResponseInputOutputItem],
) -> bool:
    """
    Determine if the last input message is a partial assistant message
    that should be continued rather than starting a new generation.

    This enables partial message completion similar to Anthropic's Messages API,
    where users can provide an incomplete assistant message and have the model
    continue from where it left off.

    A message is considered partial if:
    1. It's a ResponseOutputMessage or ResponseReasoningItem
    2. Its status is "in_progress" or "incomplete"

    Args:
        request_input: The input to the Responses API request.

    Returns:
        True if the final message should be continued, False otherwise.
    """
    if isinstance(request_input, str):
        # A simple string input is always a user message.
        return False

    if not request_input:
        return False

    last_item = request_input[-1]

    # Check if the last item is a partial assistant message.
    if isinstance(last_item, ResponseOutputMessage):
        return last_item.status in ("in_progress", "incomplete")

    # Check if the last item is a partial reasoning item.
    if isinstance(last_item, ResponseReasoningItem):
        return last_item.status in ("in_progress", "incomplete")

    if isinstance(last_item, dict):
        # Only message and reasoning items support partial completion for now.
        if last_item.get("type", "message") not in ("message", "reasoning"):
            return False
        return last_item.get("status") in ("in_progress", "incomplete")

    return False
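
# A minimal usage sketch (hypothetical inputs): a plain string never triggers
# continuation, while a trailing in-progress assistant message does.
#
#     should_continue_final_message("Hello")
#     # -> False
#     should_continue_final_message(
#         [{"type": "message", "role": "assistant", "status": "in_progress"}]
#     )
#     # -> True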
def construct_input_messages(
    *,
    request_instructions: str | None = None,
    request_input: str | list[ResponseInputOutputItem],
    prev_msg: list[ChatCompletionMessageParam] | None = None,
    prev_response_output: list[ResponseOutputItem] | None = None,
) -> list[ChatCompletionMessageParam]:
    messages: list[ChatCompletionMessageParam] = []
    if request_instructions:
        messages.append(
            {
                "role": "system",
                "content": request_instructions,
            }
        )

    # Prepend the conversation history.
    if prev_msg is not None:
        # Add the previous messages.
        messages.extend(prev_msg)
    if prev_response_output is not None:
        # Add the previous output.
        for output_item in prev_response_output:
            # NOTE: We skip the reasoning output.
            if isinstance(output_item, ResponseOutputMessage):
                for content in output_item.content:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": content.text,
                        }
                    )

    # Append the new input.
    # The Responses API supports simple text inputs without chat format.
    if isinstance(request_input, str):
        messages.append({"role": "user", "content": request_input})
    else:
        input_messages = construct_chat_messages_with_tool_call(request_input)
        messages.extend(input_messages)
    return messages
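
# A minimal sketch of the resulting message list (hypothetical values):
#
#     construct_input_messages(
#         request_instructions="You are a helpful assistant.",
#         request_input="What is vLLM?",
#     )
#     # -> [{"role": "system", "content": "You are a helpful assistant."},
#     #     {"role": "user", "content": "What is vLLM?"}]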
def _maybe_combine_reasoning_and_tool_call(
    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
    """Combine a reasoning message with the MCP tool call that follows it.

    Many models treat MCP calls and reasoning as a single message. This
    function checks whether the last message is a reasoning message and the
    current item is an MCP tool call; if so, it attaches the tool call to
    that reasoning message and returns the combined message, otherwise None.
    """
    if not (
        isinstance(item, ResponseFunctionToolCall)
        and item.id
        and item.id.startswith(MCP_PREFIX)
    ):
        return None
    if len(messages) == 0:
        return None
    last_message = messages[-1]
    if not (
        last_message.get("role") == "assistant"
        and last_message.get("reasoning") is not None
    ):
        return None

    last_message["tool_calls"] = [
        ChatCompletionMessageToolCallParam(
            id=item.call_id,
            function=FunctionCallTool(
                name=item.name,
                arguments=item.arguments,
            ),
            type="function",
        )
    ]
    return last_message
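
# A sketch of the combining behavior, assuming an MCP-prefixed item id and a
# trailing reasoning message (all values below are hypothetical):
#
#     call = ResponseFunctionToolCall(
#         type="function_call", id=f"{MCP_PREFIX}abc", call_id="call_1",
#         name="search", arguments='{"q": "vllm"}',
#     )
#     history = [{"role": "assistant", "reasoning": "Need to search."}]
#     _maybe_combine_reasoning_and_tool_call(call, history)
#     # -> the reasoning message with a "tool_calls" entry attached;
#     #    a call id without MCP_PREFIX would return None instead.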
def construct_chat_messages_with_tool_call(
    input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
    """Wrap _construct_single_message_from_response_item.

    This wrapper exists because some chat messages are built from multiple
    response items: for example, a reasoning item and an MCP tool call are
    two response items but form a single chat message.
    """
    messages: list[ChatCompletionMessageParam] = []
    for item in input_messages:
        maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
        if maybe_combined_message is not None:
            messages[-1] = maybe_combined_message
        else:
            messages.append(_construct_single_message_from_response_item(item))

    return messages
def _construct_single_message_from_response_item(
    item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
    if isinstance(item, ResponseFunctionToolCall):
        # Append the function call as a tool call.
        return ChatCompletionAssistantMessageParam(
            role="assistant",
            tool_calls=[
                ChatCompletionMessageToolCallParam(
                    id=item.call_id,
                    function=FunctionCallTool(
                        name=item.name,
                        arguments=item.arguments,
                    ),
                    type="function",
                )
            ],
        )
    elif isinstance(item, ResponseReasoningItem):
        reasoning_content = ""
        if item.encrypted_content:
            raise ValueError("Encrypted content is not supported.")
        if len(item.summary) == 1:
            reasoning_content = item.summary[0].text
        elif item.content and len(item.content) == 1:
            reasoning_content = item.content[0].text
        return {
            "role": "assistant",
            "reasoning": reasoning_content,
        }
    elif isinstance(item, ResponseOutputMessage):
        return {
            "role": "assistant",
            "content": item.content[0].text,
        }
    elif isinstance(item, ResponseFunctionToolCallOutputItem):
        return ChatCompletionToolMessageParam(
            role="tool",
            content=item.output,
            tool_call_id=item.call_id,
        )
    elif isinstance(item, dict) and item.get("type") == "function_call_output":
        # Append the function call output as a tool message.
        return ChatCompletionToolMessageParam(
            role="tool",
            content=item.get("output"),
            tool_call_id=item.get("call_id"),
        )
    return item  # type: ignore
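
# For example, a raw function_call_output dict maps to a tool message
# (hypothetical values):
#
#     _construct_single_message_from_response_item(
#         {"type": "function_call_output", "call_id": "call_1", "output": "42"}
#     )
#     # -> {"role": "tool", "content": "42", "tool_call_id": "call_1"}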
def extract_tool_types(tools: list[Tool]) -> set[str]:
    """Extract the tool types from the given tools."""
    tool_types: set[str] = set()
    for tool in tools:
        if tool.type == "mcp":
            # Allow the MCP tool type to enable built-in tools if the
            # server_label is allowlisted in
            # envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS.
            if tool.server_label in envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
                tool_types.add(tool.server_label)
        else:
            tool_types.add(tool.type)
    return tool_types
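
# A sketch using stand-in objects (SimpleNamespace instead of real Tool models;
# neither tool is "mcp", so server_label is never consulted):
#
#     from types import SimpleNamespace
#     extract_tool_types([
#         SimpleNamespace(type="function"),
#         SimpleNamespace(type="web_search_preview"),
#     ])
#     # -> {"function", "web_search_preview"}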
def convert_tool_responses_to_completions_format(tool: dict) -> dict:
    """
    Convert a flat Responses-style tool schema:
        {"type": "function", "name": "...", "description": "...", "parameters": {...}}
    into the nested Chat Completions format:
        {"type": "function", "function": {...}}
    """
    return {
        "type": "function",
        "function": tool,
    }
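
# For instance (hypothetical tool schema), the flat dict is nested verbatim
# under the "function" key, including its original "type" field:
#
#     convert_tool_responses_to_completions_format(
#         {"type": "function", "name": "get_weather", "parameters": {}}
#     )
#     # -> {"type": "function",
#     #     "function": {"type": "function", "name": "get_weather", "parameters": {}}}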
def construct_tool_dicts(
    tools: list[Tool] | None, tool_choice: ToolChoice
) -> list[dict[str, Any]] | None:
    if tools is None or tool_choice == "none":
        tool_dicts = None
    else:
        tool_dicts = [
            convert_tool_responses_to_completions_format(tool.model_dump())
            for tool in tools
        ]
    return tool_dicts
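
# A sketch of the two paths (hypothetical values): missing tools or
# tool_choice == "none" suppress tool passing entirely, otherwise each tool
# is converted to the Chat Completions format.
#
#     construct_tool_dicts(tools=None, tool_choice="auto")
#     # -> None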