(gpt-oss, oai, chat): Remove Harmony Integration and Implement Native GPT-OSS Tool Call Support (#9043)
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
|
||||||
|
# Slight differences in processing chat messages
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|||||||
@@ -174,7 +174,6 @@ async def lifespan(fast_api_app: FastAPI):
|
|||||||
tool_server=tool_server,
|
tool_server=tool_server,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print stack trace
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
@@ -859,15 +859,6 @@ class ResponseReasoningTextContent(BaseModel):
|
|||||||
type: Literal["reasoning_text"] = "reasoning_text"
|
type: Literal["reasoning_text"] = "reasoning_text"
|
||||||
|
|
||||||
|
|
||||||
class ResponseReasoningItem(BaseModel):
|
|
||||||
id: str
|
|
||||||
content: list[ResponseReasoningTextContent] = Field(default_factory=list)
|
|
||||||
summary: list = Field(default_factory=list)
|
|
||||||
type: Literal["reasoning"] = "reasoning"
|
|
||||||
encrypted_content: Optional[str] = None
|
|
||||||
status: Optional[Literal["in_progress", "completed", "incomplete"]]
|
|
||||||
|
|
||||||
|
|
||||||
ResponseInputOutputItem: TypeAlias = Union[
|
ResponseInputOutputItem: TypeAlias = Union[
|
||||||
ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
|
ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -7,18 +7,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
from openai_harmony import Message as OpenAIMessage
|
|
||||||
|
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.entrypoints.harmony_utils import (
|
|
||||||
get_developer_message,
|
|
||||||
get_stop_tokens_for_assistant_actions,
|
|
||||||
get_streamable_parser_for_assistant,
|
|
||||||
get_system_message,
|
|
||||||
parse_chat_input,
|
|
||||||
parse_output_into_messages,
|
|
||||||
render_for_completion,
|
|
||||||
)
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
@@ -57,30 +47,12 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
"""Handler for /v1/chat/completions requests"""
|
"""Handler for /v1/chat/completions requests"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
|
self,
|
||||||
|
tokenizer_manager: TokenizerManager,
|
||||||
|
template_manager: TemplateManager,
|
||||||
):
|
):
|
||||||
super().__init__(tokenizer_manager)
|
super().__init__(tokenizer_manager)
|
||||||
self.template_manager = template_manager
|
self.template_manager = template_manager
|
||||||
self.use_harmony = (
|
|
||||||
self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_harmony:
|
|
||||||
from sglang.srt.function_call.harmony_tool_parser import (
|
|
||||||
HarmonyToolCallParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.harmony_tool_parser = HarmonyToolCallParser()
|
|
||||||
|
|
||||||
# NOTE While OpenAI's chat completion API supports browsing
|
|
||||||
# for some models, currently vLLM doesn't support it. Please use the
|
|
||||||
# Responses API instead.
|
|
||||||
self.supports_browsing = False
|
|
||||||
self.browser_tool = None
|
|
||||||
# NOTE: Chat completion API does not support code interpreter.
|
|
||||||
# Please use the Responses API instead.
|
|
||||||
self.supports_code_interpreter = False
|
|
||||||
self.python_tool = None
|
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "chatcmpl-"
|
return "chatcmpl-"
|
||||||
@@ -97,6 +69,18 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
return "Tools cannot be empty if tool choice is set to required."
|
return "Tools cannot be empty if tool choice is set to required."
|
||||||
|
|
||||||
|
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
||||||
|
server_context_length = self.tokenizer_manager.server_args.context_length
|
||||||
|
if (
|
||||||
|
max_output_tokens
|
||||||
|
and server_context_length
|
||||||
|
and max_output_tokens > server_context_length
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
f"max_completion_tokens is too large: {max_output_tokens}."
|
||||||
|
f"This model supports at most {server_context_length} completion tokens."
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
@@ -107,66 +91,43 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
||||||
|
|
||||||
# Process messages and apply chat template
|
# Process messages and apply chat template
|
||||||
if not self.use_harmony:
|
processed_messages = self._process_messages(request, is_multimodal)
|
||||||
processed_messages = self._process_messages(request, is_multimodal)
|
|
||||||
|
|
||||||
# Build sampling parameters
|
# Build sampling parameters
|
||||||
sampling_params = self._build_sampling_params(
|
sampling_params = self._build_sampling_params(
|
||||||
request,
|
request,
|
||||||
processed_messages.stop,
|
processed_messages.stop,
|
||||||
processed_messages.tool_call_constraint,
|
processed_messages.tool_call_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle single vs multiple requests
|
# Handle single vs multiple requests
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
prompt_kwargs = {"text": processed_messages.prompt}
|
prompt_kwargs = {"text": processed_messages.prompt}
|
||||||
else:
|
|
||||||
if isinstance(processed_messages.prompt_ids, str):
|
|
||||||
prompt_kwargs = {"text": processed_messages.prompt_ids}
|
|
||||||
else:
|
|
||||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
|
||||||
**prompt_kwargs,
|
|
||||||
image_data=processed_messages.image_data,
|
|
||||||
video_data=processed_messages.video_data,
|
|
||||||
audio_data=processed_messages.audio_data,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
return_logprob=request.logprobs,
|
|
||||||
logprob_start_len=-1,
|
|
||||||
top_logprobs_num=request.top_logprobs or 0,
|
|
||||||
stream=request.stream,
|
|
||||||
return_text_in_logprobs=True,
|
|
||||||
modalities=processed_messages.modalities,
|
|
||||||
lora_path=request.lora_path,
|
|
||||||
bootstrap_host=request.bootstrap_host,
|
|
||||||
bootstrap_port=request.bootstrap_port,
|
|
||||||
bootstrap_room=request.bootstrap_room,
|
|
||||||
return_hidden_states=request.return_hidden_states,
|
|
||||||
rid=request.rid,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
processed_messages, prompt_ids = self._make_request_with_harmony(request)
|
if isinstance(processed_messages.prompt_ids, str):
|
||||||
|
prompt_kwargs = {"text": processed_messages.prompt_ids}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
input_ids=prompt_ids,
|
**prompt_kwargs,
|
||||||
sampling_params=self._build_sampling_params(
|
image_data=processed_messages.image_data,
|
||||||
request,
|
video_data=processed_messages.video_data,
|
||||||
request.stop,
|
audio_data=processed_messages.audio_data,
|
||||||
tool_call_constraint=None,
|
sampling_params=sampling_params,
|
||||||
),
|
return_logprob=request.logprobs,
|
||||||
stream=request.stream,
|
logprob_start_len=-1,
|
||||||
return_logprob=request.logprobs,
|
top_logprobs_num=request.top_logprobs or 0,
|
||||||
logprob_start_len=-1,
|
stream=request.stream,
|
||||||
top_logprobs_num=request.top_logprobs or 0,
|
return_text_in_logprobs=True,
|
||||||
return_text_in_logprobs=True,
|
modalities=processed_messages.modalities,
|
||||||
lora_path=request.lora_path,
|
lora_path=request.lora_path,
|
||||||
bootstrap_host=request.bootstrap_host,
|
bootstrap_host=request.bootstrap_host,
|
||||||
bootstrap_port=request.bootstrap_port,
|
bootstrap_port=request.bootstrap_port,
|
||||||
bootstrap_room=request.bootstrap_room,
|
bootstrap_room=request.bootstrap_room,
|
||||||
return_hidden_states=request.return_hidden_states,
|
return_hidden_states=request.return_hidden_states,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
|
|
||||||
@@ -251,14 +212,16 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
reasoning_effort=request.reasoning_effort,
|
||||||
|
builtin_tools=[],
|
||||||
**(
|
**(
|
||||||
request.chat_template_kwargs if request.chat_template_kwargs else {}
|
request.chat_template_kwargs if request.chat_template_kwargs else {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# This except branch will be triggered when the chosen model
|
# This except branch will be triggered when the chosen model
|
||||||
# has a different tools input format that is not compatible
|
# has a different tools input format that is not compatible
|
||||||
# with openAI's apply_chat_template tool_call format, like Mistral.
|
# with openAI's apply_chat_template tool_call format, like Mistral.
|
||||||
tools = (
|
tools = (
|
||||||
[t if "function" in t else {"function": t} for t in tools]
|
[t if "function" in t else {"function": t} for t in tools]
|
||||||
if tools
|
if tools
|
||||||
@@ -269,6 +232,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
reasoning_effort=request.reasoning_effort,
|
||||||
|
builtin_tools=[],
|
||||||
**(
|
**(
|
||||||
request.chat_template_kwargs if request.chat_template_kwargs else {}
|
request.chat_template_kwargs if request.chat_template_kwargs else {}
|
||||||
),
|
),
|
||||||
@@ -459,12 +424,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
cached_tokens = {}
|
cached_tokens = {}
|
||||||
hidden_states = {}
|
hidden_states = {}
|
||||||
|
|
||||||
# Harmony tracking
|
|
||||||
if self.use_harmony:
|
|
||||||
harmony_parsers = [
|
|
||||||
get_streamable_parser_for_assistant() for _ in range(request.n)
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for content in self.tokenizer_manager.generate_request(
|
async for content in self.tokenizer_manager.generate_request(
|
||||||
adapted_request, raw_request
|
adapted_request, raw_request
|
||||||
@@ -511,58 +470,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# Process content delta
|
stream_buffer = stream_buffers.get(index, "")
|
||||||
if self.use_harmony:
|
delta = content["text"][len(stream_buffer) :]
|
||||||
harmony_parser = harmony_parsers[index]
|
stream_buffers[index] = stream_buffer + delta
|
||||||
|
|
||||||
new_token_ids = content["output_ids"]
|
|
||||||
for token_id in new_token_ids:
|
|
||||||
harmony_parser.process(token_id)
|
|
||||||
|
|
||||||
is_final = harmony_parser.current_channel == "final"
|
|
||||||
is_analysis = harmony_parser.current_channel == "analysis"
|
|
||||||
delta = harmony_parser.last_content_delta or ""
|
|
||||||
|
|
||||||
if is_analysis:
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(reasoning_content=delta),
|
|
||||||
finish_reason=None,
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=int(time.time()),
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
continue
|
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(content=delta if delta else None),
|
|
||||||
finish_reason=None,
|
|
||||||
matched_stop=None,
|
|
||||||
logprobs=choice_logprobs,
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=int(time.time()),
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
stream_buffer = stream_buffers.get(index, "")
|
|
||||||
delta = content["text"][len(stream_buffer) :]
|
|
||||||
stream_buffers[index] = stream_buffer + delta
|
|
||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
if (
|
if (
|
||||||
self.tokenizer_manager.server_args.reasoning_parser
|
self.tokenizer_manager.server_args.reasoning_parser
|
||||||
and request.separate_reasoning
|
and request.separate_reasoning
|
||||||
and not self.use_harmony
|
|
||||||
):
|
):
|
||||||
reasoning_text, delta = self._process_reasoning_stream(
|
reasoning_text, delta = self._process_reasoning_stream(
|
||||||
index, delta, reasoning_parser_dict, content, request
|
index, delta, reasoning_parser_dict, content, request
|
||||||
@@ -581,27 +496,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
if self.use_harmony and not is_final:
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(reasoning_content=delta),
|
|
||||||
finish_reason=None,
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=int(time.time()),
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
# Handle tool calls
|
# Handle tool calls
|
||||||
# TODO: support tool call parsing for harmony
|
if request.tool_choice != "none" and request.tools:
|
||||||
if (
|
|
||||||
request.tool_choice != "none"
|
|
||||||
and request.tools
|
|
||||||
and not self.use_harmony
|
|
||||||
):
|
|
||||||
async for chunk in self._process_tool_call_stream(
|
async for chunk in self._process_tool_call_stream(
|
||||||
index,
|
index,
|
||||||
delta,
|
delta,
|
||||||
@@ -765,76 +661,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||||
text = ret_item["text"]
|
text = ret_item["text"]
|
||||||
output_ids = ret_item["output_ids"]
|
|
||||||
|
|
||||||
if self.use_harmony:
|
|
||||||
parser = parse_output_into_messages(output_ids)
|
|
||||||
output_msgs = parser.messages
|
|
||||||
if len(output_msgs) == 0:
|
|
||||||
# The generation has stopped during reasoning.
|
|
||||||
is_tool_call = False
|
|
||||||
reasoning_content = parser.current_content
|
|
||||||
final_content = None
|
|
||||||
elif len(output_msgs) == 1:
|
|
||||||
# The generation has stopped during final message.
|
|
||||||
is_tool_call = False
|
|
||||||
reasoning_content = output_msgs[0].content[0].text
|
|
||||||
final_content = parser.current_content
|
|
||||||
else:
|
|
||||||
if len(output_msgs) != 2:
|
|
||||||
raise ValueError(
|
|
||||||
"Expected 2 output messages (reasoning and final), "
|
|
||||||
f"but got {len(output_msgs)}."
|
|
||||||
)
|
|
||||||
reasoning_msg, final_msg = output_msgs
|
|
||||||
reasoning_content = reasoning_msg.content[0].text
|
|
||||||
final_content = final_msg.content[0].text
|
|
||||||
is_tool_call = final_msg.recipient is not None
|
|
||||||
|
|
||||||
if is_tool_call:
|
|
||||||
# Extract tool call information from final message
|
|
||||||
tool_call = (
|
|
||||||
self.harmony_tool_parser.extract_tool_calls_from_message(
|
|
||||||
final_msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tool_calls = [tool_call] if tool_call else []
|
|
||||||
|
|
||||||
message = ChatMessage(
|
|
||||||
role="assistant",
|
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
content=None, # Tool calls don't have regular content
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Normal message
|
|
||||||
message = ChatMessage(
|
|
||||||
role="assistant",
|
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
content=final_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_tool_call:
|
|
||||||
finish_reason_type = "tool_calls"
|
|
||||||
elif finish_reason:
|
|
||||||
finish_reason_type = (
|
|
||||||
finish_reason["type"] if finish_reason else "stop"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
finish_reason_type = "stop"
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
|
||||||
index=idx,
|
|
||||||
message=message,
|
|
||||||
logprobs=choice_logprobs,
|
|
||||||
finish_reason=finish_reason_type,
|
|
||||||
matched_stop=(
|
|
||||||
finish_reason["matched"]
|
|
||||||
if finish_reason and "matched" in finish_reason
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
choices.append(choice_data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
reasoning_text = None
|
reasoning_text = None
|
||||||
@@ -1184,33 +1010,3 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
return f"data: {chunk.model_dump_json()}\n\n"
|
return f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _make_request_with_harmony(
|
|
||||||
self,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
):
|
|
||||||
messages: list[OpenAIMessage] = []
|
|
||||||
|
|
||||||
# Add system message.
|
|
||||||
# In Chat Completion API, browsing is enabled by default if the model
|
|
||||||
# supports it.
|
|
||||||
assert not self.supports_browsing
|
|
||||||
assert not self.supports_code_interpreter
|
|
||||||
sys_msg = get_system_message(
|
|
||||||
reasoning_effort=request.reasoning_effort,
|
|
||||||
browser_description=None,
|
|
||||||
python_description=None,
|
|
||||||
)
|
|
||||||
messages.append(sys_msg)
|
|
||||||
|
|
||||||
# Add developer message.
|
|
||||||
dev_msg = get_developer_message()
|
|
||||||
messages.append(dev_msg)
|
|
||||||
|
|
||||||
# Add user message.
|
|
||||||
for chat_msg in request.messages:
|
|
||||||
messages.append(parse_chat_input(chat_msg))
|
|
||||||
|
|
||||||
# Render prompt token ids.
|
|
||||||
prompt_token_ids = render_for_completion(messages)
|
|
||||||
return messages, prompt_token_ids
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|||||||
from sglang.srt.function_call.core_types import ToolCallItem
|
from sglang.srt.function_call.core_types import ToolCallItem
|
||||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
||||||
|
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
|
||||||
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
@@ -41,6 +42,7 @@ class FunctionCallParser:
|
|||||||
"qwen3_coder": Qwen3CoderDetector,
|
"qwen3_coder": Qwen3CoderDetector,
|
||||||
"glm45": Glm4MoeDetector,
|
"glm45": Glm4MoeDetector,
|
||||||
"step3": Step3Detector,
|
"step3": Step3Detector,
|
||||||
|
"gpt-oss": GptOssDetector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||||
|
|||||||
331
python/sglang/srt/function_call/gpt_oss_detector.py
Normal file
331
python/sglang/srt/function_call/gpt_oss_detector.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import (
|
||||||
|
StreamingParseResult,
|
||||||
|
ToolCallItem,
|
||||||
|
_GetInfoFunc,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssDetector(BaseFormatDetector):
|
||||||
|
"""
|
||||||
|
Detector for T4-style function calls with channel format.
|
||||||
|
|
||||||
|
Supports two formats:
|
||||||
|
1. Direct function call: <|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{args}<|call|>
|
||||||
|
2. Commentary with action plan: <|channel|>commentary<|message|>{content}<|end|>
|
||||||
|
|
||||||
|
For parallel function calls, each call is self-contained and starts with its own channel:
|
||||||
|
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"SF"}<|call|>
|
||||||
|
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"SF attractions"}<|call|>
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Single: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"San Francisco"}<|call|>commentary
|
||||||
|
Multiple: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"Paris"}<|call|>commentary<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"Paris tourism"}<|call|>
|
||||||
|
With Action Plan: <|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|><|start|>assistant<|channel|>commentary to=functions.x<|constrain|>json<|message|>{"template": "basic_html", "path": "index.html"}<|call|>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.bot_token = "<|start|>assistant<|channel|>commentary"
|
||||||
|
self.eot_token = "<|call|>"
|
||||||
|
# TODO: no clear indication how parallel tool call response format is
|
||||||
|
self.tool_call_separator = ""
|
||||||
|
|
||||||
|
# Pattern for complete function calls with to= parameter
|
||||||
|
# Handles both <|call|> and <|call|>commentary endings
|
||||||
|
# Also handles optional <|start|>assistant prefix and whitespace after function name
|
||||||
|
self.function_call_pattern = re.compile(
|
||||||
|
r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
|
||||||
|
r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pattern for streaming function calls (incomplete)
|
||||||
|
# Also handles optional whitespace after function name
|
||||||
|
self.streaming_pattern = re.compile(
|
||||||
|
r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
|
||||||
|
r"<\|constrain\|>json<\|message\|>(.*)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pattern for commentary with action plan (no to= parameter)
|
||||||
|
self.commentary_pattern = re.compile(
|
||||||
|
r"<\|channel\|>commentary<\|message\|>(.*?)<\|end\|>",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._last_arguments = ""
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""Check if text contains TypeScript-style function call markers."""
|
||||||
|
return self.bot_token in text
|
||||||
|
|
||||||
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
|
"""Parse TypeScript-style function calls from complete text."""
|
||||||
|
if not self.has_tool_call(text):
|
||||||
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
|
|
||||||
|
tool_indices = self._get_tool_indices(tools)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
tool_index = 0
|
||||||
|
|
||||||
|
# Process the entire text to handle mixed commentary and tool calls
|
||||||
|
normal_text_parts = []
|
||||||
|
|
||||||
|
# Find all commentary sections (both with and without to=)
|
||||||
|
all_commentary_pattern = re.compile(
|
||||||
|
r"<\|channel\|>commentary(?:\s+to=[^<]*)?<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track processed positions to avoid double-processing
|
||||||
|
processed_ranges = []
|
||||||
|
|
||||||
|
# First, extract all tool calls
|
||||||
|
for match in self.function_call_pattern.finditer(text):
|
||||||
|
full_function_name = match.group(1)
|
||||||
|
args_content = match.group(2)
|
||||||
|
processed_ranges.append((match.start(), match.end()))
|
||||||
|
|
||||||
|
function_name = (
|
||||||
|
full_function_name.split(".")[-1]
|
||||||
|
if "." in full_function_name
|
||||||
|
else full_function_name
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
arguments = json.loads(args_content) if args_content.strip() else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if function_name in tool_indices:
|
||||||
|
calls.append(
|
||||||
|
ToolCallItem(
|
||||||
|
tool_index=tool_index,
|
||||||
|
name=function_name,
|
||||||
|
parameters=json.dumps(arguments, ensure_ascii=False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_index += 1
|
||||||
|
|
||||||
|
# Then, find non-tool-call commentary sections for normal text
|
||||||
|
for match in all_commentary_pattern.finditer(text):
|
||||||
|
# Check if this match overlaps with any processed tool call
|
||||||
|
match_start, match_end = match.start(), match.end()
|
||||||
|
is_tool_call = any(
|
||||||
|
start <= match_start < end or start < match_end <= end
|
||||||
|
for start, end in processed_ranges
|
||||||
|
)
|
||||||
|
|
||||||
|
# If this commentary is not part of a tool call, include it in normal text
|
||||||
|
if not is_tool_call:
|
||||||
|
content = match.group(1).strip()
|
||||||
|
if content:
|
||||||
|
normal_text_parts.append(content)
|
||||||
|
|
||||||
|
# Handle remaining text after all matches
|
||||||
|
if processed_ranges:
|
||||||
|
last_match_end = max(end for _, end in processed_ranges)
|
||||||
|
if last_match_end < len(text):
|
||||||
|
remaining_text = text[last_match_end:]
|
||||||
|
|
||||||
|
# Clean up <|start|>assistant prefixes and extract final content
|
||||||
|
# Remove standalone <|start|>assistant prefixes
|
||||||
|
remaining_text = re.sub(r"<\|start\|>assistant(?!\w)", "", remaining_text)
|
||||||
|
|
||||||
|
# Extract content from final channel if present
|
||||||
|
final_pattern = re.compile(
|
||||||
|
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)", re.DOTALL
|
||||||
|
)
|
||||||
|
final_match = final_pattern.search(remaining_text)
|
||||||
|
|
||||||
|
if final_match:
|
||||||
|
# Get everything before final channel + final channel content
|
||||||
|
before_final = remaining_text[: final_match.start()].strip()
|
||||||
|
final_content = final_match.group(1).strip()
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if before_final:
|
||||||
|
parts.append(before_final)
|
||||||
|
if final_content:
|
||||||
|
parts.append(final_content)
|
||||||
|
remaining_text = " ".join(parts) if parts else ""
|
||||||
|
|
||||||
|
remaining_text = remaining_text.strip()
|
||||||
|
|
||||||
|
if remaining_text:
|
||||||
|
normal_text_parts.append(remaining_text)
|
||||||
|
|
||||||
|
# Combine all normal text parts
|
||||||
|
final_normal_text = " ".join(part for part in normal_text_parts if part).strip()
|
||||||
|
return StreamingParseResult(normal_text=final_normal_text, calls=calls)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
|
||||||
|
self, new_text: str, tools: List[Tool]
|
||||||
|
) -> StreamingParseResult:
|
||||||
|
"""Parse incremental streaming text for TypeScript-style function calls."""
|
||||||
|
self._buffer += new_text
|
||||||
|
current_text = self._buffer
|
||||||
|
|
||||||
|
# Check if we have a tool call
|
||||||
|
has_tool_call = "<|channel|>commentary to=" in current_text
|
||||||
|
|
||||||
|
if not has_tool_call and current_text:
|
||||||
|
# Check for commentary without function calls
|
||||||
|
commentary_match = self.commentary_pattern.search(current_text)
|
||||||
|
if commentary_match:
|
||||||
|
commentary_content = commentary_match.group(1)
|
||||||
|
self._buffer = current_text[commentary_match.end() :]
|
||||||
|
return StreamingParseResult(normal_text=commentary_content, calls=[])
|
||||||
|
|
||||||
|
# Check for final channel content
|
||||||
|
final_pattern = re.compile(
|
||||||
|
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
final_match = final_pattern.search(current_text)
|
||||||
|
if final_match:
|
||||||
|
final_content = final_match.group(1).strip()
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(normal_text=final_content, calls=[])
|
||||||
|
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(normal_text=new_text, calls=[])
|
||||||
|
|
||||||
|
if not hasattr(self, "_tool_indices"):
|
||||||
|
self._tool_indices = self._get_tool_indices(tools)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
try:
|
||||||
|
# Check for streaming function call
|
||||||
|
match = self.streaming_pattern.search(current_text)
|
||||||
|
if match:
|
||||||
|
full_function_name = match.group(1)
|
||||||
|
args_content = match.group(2)
|
||||||
|
|
||||||
|
function_name = (
|
||||||
|
full_function_name.split(".")[-1]
|
||||||
|
if "." in full_function_name
|
||||||
|
else full_function_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize state if this is the first tool call
|
||||||
|
if self.current_tool_id == -1:
|
||||||
|
self.current_tool_id = 0
|
||||||
|
self.prev_tool_call_arr = []
|
||||||
|
self.streamed_args_for_tool = [""]
|
||||||
|
|
||||||
|
# Ensure we have enough entries in tracking arrays
|
||||||
|
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||||
|
self.prev_tool_call_arr.append({})
|
||||||
|
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
|
if not self.current_tool_name_sent:
|
||||||
|
calls.append(
|
||||||
|
ToolCallItem(
|
||||||
|
tool_index=self.current_tool_id,
|
||||||
|
name=function_name,
|
||||||
|
parameters="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.current_tool_name_sent = True
|
||||||
|
# Store the tool call info
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": {},
|
||||||
|
}
|
||||||
|
self.streamed_args_for_tool[self.current_tool_id] = ""
|
||||||
|
|
||||||
|
# Check if we have a complete function call
|
||||||
|
complete_match = self.function_call_pattern.search(current_text)
|
||||||
|
if complete_match:
|
||||||
|
args_content = complete_match.group(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(args_content)
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id][
|
||||||
|
"arguments"
|
||||||
|
] = parsed_args
|
||||||
|
|
||||||
|
# Send complete arguments if we haven't sent them yet
|
||||||
|
if not self.streamed_args_for_tool[self.current_tool_id]:
|
||||||
|
# Send the complete arguments as JSON string
|
||||||
|
calls.append(
|
||||||
|
ToolCallItem(
|
||||||
|
tool_index=self.current_tool_id,
|
||||||
|
name=None,
|
||||||
|
parameters=json.dumps(
|
||||||
|
parsed_args, ensure_ascii=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.streamed_args_for_tool[self.current_tool_id] = (
|
||||||
|
json.dumps(parsed_args, ensure_ascii=False)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Remove the completed function call from buffer
|
||||||
|
remaining_after_call = current_text[complete_match.end() :]
|
||||||
|
|
||||||
|
# Clean up <|start|>assistant prefixes and extract final content
|
||||||
|
remaining_after_call = re.sub(
|
||||||
|
r"<\|start\|>assistant(?!\w)", "", remaining_after_call
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract content from final channel if present
|
||||||
|
final_pattern = re.compile(
|
||||||
|
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
final_match = final_pattern.search(remaining_after_call)
|
||||||
|
|
||||||
|
if final_match:
|
||||||
|
before_final = remaining_after_call[
|
||||||
|
: final_match.start()
|
||||||
|
].strip()
|
||||||
|
final_content = final_match.group(1).strip()
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if before_final:
|
||||||
|
parts.append(before_final)
|
||||||
|
if final_content:
|
||||||
|
parts.append(final_content)
|
||||||
|
remaining_after_call = " ".join(parts) if parts else ""
|
||||||
|
|
||||||
|
self._buffer = remaining_after_call.strip()
|
||||||
|
|
||||||
|
# Reset state for next tool call
|
||||||
|
self.current_tool_name_sent = False
|
||||||
|
self.current_tool_id += 1
|
||||||
|
|
||||||
|
# Return final content if available
|
||||||
|
final_text = ""
|
||||||
|
if final_match and final_content:
|
||||||
|
final_text = final_content
|
||||||
|
elif remaining_after_call:
|
||||||
|
final_text = remaining_after_call
|
||||||
|
|
||||||
|
return StreamingParseResult(normal_text=final_text, calls=calls)
|
||||||
|
|
||||||
|
return StreamingParseResult(normal_text="", calls=calls)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in parse_streaming_increment: {e}")
|
||||||
|
return StreamingParseResult(normal_text=current_text, calls=[])
|
||||||
|
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Harmony tool call parser for processing tool calls in harmony models."""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
|
||||||
ChatMessage,
|
|
||||||
FunctionResponse,
|
|
||||||
ToolCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HarmonyToolCallParser:
|
|
||||||
"""Parser for extracting tool calls from harmony model outputs."""
|
|
||||||
|
|
||||||
def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
|
|
||||||
"""
|
|
||||||
Extract tool call from a single message if it's a tool call.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg: The harmony message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ToolCall if the message is a tool call, None otherwise
|
|
||||||
"""
|
|
||||||
if (
|
|
||||||
msg.channel == "commentary"
|
|
||||||
and msg.recipient
|
|
||||||
and msg.recipient.startswith("functions.")
|
|
||||||
):
|
|
||||||
function_name = msg.recipient.split(".")[-1]
|
|
||||||
arguments = msg.content[0].text if msg.content else "{}"
|
|
||||||
|
|
||||||
return ToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
|
||||||
function=FunctionResponse(
|
|
||||||
name=function_name,
|
|
||||||
arguments=arguments,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def process_streaming_chunk(
|
|
||||||
self,
|
|
||||||
harmony_parser,
|
|
||||||
index: int,
|
|
||||||
tool_call_trackers: dict,
|
|
||||||
stream_buffers: dict,
|
|
||||||
) -> Tuple[Optional[dict], bool, Optional[str]]:
|
|
||||||
"""
|
|
||||||
Process a streaming chunk for tool calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
harmony_parser: The harmony parser instance
|
|
||||||
index: The choice index
|
|
||||||
tool_call_trackers: Dict tracking tool calls per choice
|
|
||||||
stream_buffers: Dict for buffering content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (tool_call_data, is_tool_call, delta)
|
|
||||||
"""
|
|
||||||
# Check if we're in a tool call
|
|
||||||
is_tool_call = (
|
|
||||||
harmony_parser.current_channel == "commentary"
|
|
||||||
and harmony_parser.current_recipient
|
|
||||||
and harmony_parser.current_recipient.startswith("functions.")
|
|
||||||
)
|
|
||||||
|
|
||||||
delta = harmony_parser.last_content_delta or ""
|
|
||||||
tool_call_data = None
|
|
||||||
|
|
||||||
if is_tool_call:
|
|
||||||
# Handle tool call streaming
|
|
||||||
function_name = harmony_parser.current_recipient.split(".")[-1]
|
|
||||||
|
|
||||||
# Track tool call indices per choice
|
|
||||||
if index not in tool_call_trackers:
|
|
||||||
tool_call_trackers[index] = {"count": 0, "current_function": None}
|
|
||||||
|
|
||||||
# Check if we just started a new tool call
|
|
||||||
tool_call_tracker = tool_call_trackers[index]
|
|
||||||
if tool_call_tracker["current_function"] != function_name:
|
|
||||||
# New tool call started
|
|
||||||
tool_call_tracker["current_function"] = function_name
|
|
||||||
tool_call_index = tool_call_tracker["count"]
|
|
||||||
tool_call_tracker["count"] += 1
|
|
||||||
|
|
||||||
# Store the tool call index for this function
|
|
||||||
tool_call_key = f"{index}_{function_name}"
|
|
||||||
stream_buffers[tool_call_key] = {
|
|
||||||
"index": tool_call_index,
|
|
||||||
"content": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
tool_call_data = {
|
|
||||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
|
||||||
"index": tool_call_index,
|
|
||||||
"function_name": function_name,
|
|
||||||
"arguments": delta,
|
|
||||||
"is_first_chunk": True,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Subsequent chunks for the same tool call
|
|
||||||
tool_call_key = f"{index}_{function_name}"
|
|
||||||
tool_call_index = stream_buffers[tool_call_key]["index"]
|
|
||||||
|
|
||||||
tool_call_data = {
|
|
||||||
"id": None,
|
|
||||||
"index": tool_call_index,
|
|
||||||
"function_name": None,
|
|
||||||
"arguments": delta,
|
|
||||||
"is_first_chunk": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
stream_buffers[tool_call_key]["content"] += delta
|
|
||||||
|
|
||||||
return tool_call_data, is_tool_call, delta
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from typing import Dict, Optional, Tuple, Type
|
from typing import Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
|
||||||
@@ -185,6 +186,320 @@ class KimiDetector(BaseReasoningFormatDetector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssDetector(BaseReasoningFormatDetector):
|
||||||
|
"""
|
||||||
|
Detector for T4-style reasoning format.
|
||||||
|
|
||||||
|
Assumes reasoning format with two channels:
|
||||||
|
<|channel|>analysis<|message|>...reasoning content...<|end|>
|
||||||
|
<|start|>assistant<|channel|>final<|message|>...final answer...<|return|>
|
||||||
|
|
||||||
|
Returns content from 'analysis' channel as reasoning_text
|
||||||
|
and content from 'final' channel as normal_text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_reasoning (bool): If False, accumulates reasoning content until complete.
|
||||||
|
If True, streams reasoning content as it arrives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
|
||||||
|
# TypeScript uses channel tokens instead of simple start/end tokens
|
||||||
|
super().__init__(
|
||||||
|
"<|channel|>analysis<|message|>",
|
||||||
|
"<|end|>",
|
||||||
|
force_reasoning=True,
|
||||||
|
stream_reasoning=stream_reasoning,
|
||||||
|
)
|
||||||
|
self.final_channel_start = "<|start|>assistant<|channel|>final<|message|>"
|
||||||
|
self.final_channel_end = "<|return|>"
|
||||||
|
self._in_final_channel = False
|
||||||
|
self._analysis_complete = False
|
||||||
|
self._in_reasoning = True
|
||||||
|
|
||||||
|
def detect_and_parse(self, text: str) -> StreamingParseResult:
|
||||||
|
"""
|
||||||
|
One-time parsing: Detects and parses both analysis and final channels.
|
||||||
|
Tool call channels are preserved in normal_text for downstream processing.
|
||||||
|
|
||||||
|
HACK: Also handles simplified format where text starts with "analysis" and transitions
|
||||||
|
to "assistantfinal" without full channel markers.
|
||||||
|
"""
|
||||||
|
# HACK: Handle simplified format (analysis...assistantfinal) without channel markers
|
||||||
|
if (
|
||||||
|
text.startswith("analysis")
|
||||||
|
and "assistantfinal" in text
|
||||||
|
and "<|channel|>" not in text
|
||||||
|
):
|
||||||
|
# Split on "assistantfinal"
|
||||||
|
parts = text.split("assistantfinal", 1)
|
||||||
|
self._in_reasoning = False
|
||||||
|
if len(parts) == 2:
|
||||||
|
reasoning_text = parts[0][
|
||||||
|
len("analysis") :
|
||||||
|
].strip() # Remove "analysis" prefix
|
||||||
|
normal_text = parts[1].strip()
|
||||||
|
return StreamingParseResult(
|
||||||
|
normal_text=normal_text, reasoning_text=reasoning_text
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_parts = []
|
||||||
|
normal_parts = []
|
||||||
|
current_pos = 0
|
||||||
|
|
||||||
|
# Process text sequentially to preserve tool calls between analysis sections
|
||||||
|
while current_pos < len(text):
|
||||||
|
# Look for next analysis channel
|
||||||
|
analysis_start_idx = text.find(self.think_start_token, current_pos)
|
||||||
|
|
||||||
|
if analysis_start_idx == -1:
|
||||||
|
# No more analysis channels, rest goes to remaining
|
||||||
|
break
|
||||||
|
|
||||||
|
# Preserve any content before this analysis channel (could include tool calls)
|
||||||
|
if analysis_start_idx > current_pos:
|
||||||
|
between_content = text[current_pos:analysis_start_idx]
|
||||||
|
# This content will be added to normal_parts later
|
||||||
|
normal_parts.append(between_content)
|
||||||
|
|
||||||
|
# Extract analysis content
|
||||||
|
analysis_content_start = analysis_start_idx + len(self.think_start_token)
|
||||||
|
analysis_end_idx = text.find(self.think_end_token, analysis_content_start)
|
||||||
|
|
||||||
|
if analysis_end_idx != -1:
|
||||||
|
reasoning_parts.append(
|
||||||
|
text[analysis_content_start:analysis_end_idx].strip()
|
||||||
|
)
|
||||||
|
current_pos = analysis_end_idx + len(self.think_end_token)
|
||||||
|
else:
|
||||||
|
# Analysis not complete
|
||||||
|
reasoning_parts.append(text[analysis_content_start:].strip())
|
||||||
|
reasoning_text = "".join(reasoning_parts)
|
||||||
|
return StreamingParseResult(reasoning_text=reasoning_text)
|
||||||
|
|
||||||
|
# Add any remaining text after all analysis sections
|
||||||
|
if current_pos < len(text):
|
||||||
|
remaining = text[current_pos:]
|
||||||
|
normal_parts.append(remaining)
|
||||||
|
|
||||||
|
# Process non-analysis content for commentary sections
|
||||||
|
full_normal_text = "".join(normal_parts)
|
||||||
|
|
||||||
|
# Extract reasoning from non-tool-call commentary sections
|
||||||
|
# Tool calls have "to=" in their header, regular commentary does not
|
||||||
|
commentary_pattern = re.compile(
|
||||||
|
r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
cleaned_text = full_normal_text
|
||||||
|
for match in reversed(list(commentary_pattern.finditer(full_normal_text))):
|
||||||
|
# Check if this commentary is a tool call by looking at the text before <|message|>
|
||||||
|
match_start = match.start()
|
||||||
|
# Find the start of this commentary section
|
||||||
|
commentary_start = full_normal_text.rfind(
|
||||||
|
"<|channel|>commentary", 0, match_start
|
||||||
|
)
|
||||||
|
if commentary_start != -1:
|
||||||
|
# Extract text between "commentary" and "<|message|>"
|
||||||
|
message_pos = full_normal_text.find("<|message|>", commentary_start)
|
||||||
|
if message_pos != -1:
|
||||||
|
between_text = full_normal_text[commentary_start:message_pos]
|
||||||
|
# If no "to=" found, this is regular commentary (reasoning content)
|
||||||
|
if " to=" not in between_text:
|
||||||
|
content = match.group(1).strip()
|
||||||
|
reasoning_parts.append(content)
|
||||||
|
# Remove this commentary section from normal text
|
||||||
|
cleaned_text = (
|
||||||
|
cleaned_text[: match.start()] + cleaned_text[match.end() :]
|
||||||
|
)
|
||||||
|
|
||||||
|
full_normal_text = cleaned_text
|
||||||
|
|
||||||
|
# Combine all reasoning parts
|
||||||
|
reasoning_text = "".join(reasoning_parts)
|
||||||
|
|
||||||
|
# Process full_normal_text for final output
|
||||||
|
normal_text = ""
|
||||||
|
if self.final_channel_start in full_normal_text:
|
||||||
|
final_start = full_normal_text.find(self.final_channel_start)
|
||||||
|
final_content_start = final_start + len(self.final_channel_start)
|
||||||
|
final_end = full_normal_text.find(
|
||||||
|
self.final_channel_end, final_content_start
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_end != -1:
|
||||||
|
# Extract content before final channel (includes tool calls)
|
||||||
|
before_final = full_normal_text[:final_start].strip()
|
||||||
|
# Extract ONLY the final channel content (not the channel markers)
|
||||||
|
final_text = full_normal_text[final_content_start:final_end].strip()
|
||||||
|
# Extract content after final channel
|
||||||
|
after_final = full_normal_text[
|
||||||
|
final_end + len(self.final_channel_end) :
|
||||||
|
].strip()
|
||||||
|
|
||||||
|
# For tool calls + final answer: concatenate tool calls with final text
|
||||||
|
parts = []
|
||||||
|
if before_final:
|
||||||
|
parts.append(before_final)
|
||||||
|
if final_text:
|
||||||
|
parts.append(final_text)
|
||||||
|
if after_final:
|
||||||
|
parts.append(after_final)
|
||||||
|
normal_text = " ".join(parts)
|
||||||
|
else:
|
||||||
|
# Final channel not complete - extract what we have
|
||||||
|
# Look for just <|channel|>final<|message|> without <|return|>
|
||||||
|
alt_final_start = full_normal_text.find("<|channel|>final<|message|>")
|
||||||
|
if alt_final_start != -1:
|
||||||
|
before_alt_final = full_normal_text[:alt_final_start].strip()
|
||||||
|
alt_final_content = full_normal_text[
|
||||||
|
alt_final_start + len("<|channel|>final<|message|>") :
|
||||||
|
].strip()
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if before_alt_final:
|
||||||
|
parts.append(before_alt_final)
|
||||||
|
if alt_final_content:
|
||||||
|
parts.append(alt_final_content)
|
||||||
|
normal_text = " ".join(parts)
|
||||||
|
else:
|
||||||
|
normal_text = full_normal_text.strip()
|
||||||
|
else:
|
||||||
|
# No final channel, treat all as normal text (includes tool calls)
|
||||||
|
normal_text = full_normal_text.strip()
|
||||||
|
|
||||||
|
return StreamingParseResult(
|
||||||
|
normal_text=normal_text, reasoning_text=reasoning_text
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
||||||
|
"""
|
||||||
|
Streaming incremental parsing for GPT-OSS format.
|
||||||
|
|
||||||
|
This is a simplified streaming implementation that accumulates content
|
||||||
|
and delegates to the non-streaming parser for complex multi-channel parsing.
|
||||||
|
TODO: Implement proper incremental parsing for better streaming performance.
|
||||||
|
"""
|
||||||
|
self._buffer += new_text
|
||||||
|
|
||||||
|
if not self._in_reasoning:
|
||||||
|
return StreamingParseResult(normal_text=new_text)
|
||||||
|
|
||||||
|
# Check if we have complete sections to process
|
||||||
|
# For GPT-OSS, we need to wait for complete channel sections
|
||||||
|
# HACK: For now, use simplified approach - wait for key markers before processing
|
||||||
|
key_markers = ["<|end|>", "<|call|>", "<|return|>", "assistantfinal"]
|
||||||
|
has_complete_section = any(marker in self._buffer for marker in key_markers)
|
||||||
|
|
||||||
|
if not has_complete_section:
|
||||||
|
# Still accumulating, don't process yet
|
||||||
|
return StreamingParseResult()
|
||||||
|
|
||||||
|
# Handle simplified format (analysis...assistantfinal) with true incremental streaming
|
||||||
|
if (
|
||||||
|
"<|channel|>" not in self._buffer
|
||||||
|
): # Simplified format without channel markers
|
||||||
|
if self._buffer.startswith("analysis"):
|
||||||
|
# Check if we have the transition to assistantfinal
|
||||||
|
if "assistantfinal" in self._buffer:
|
||||||
|
self._in_reasoning = False
|
||||||
|
# Complete reasoning section - extract and stream it
|
||||||
|
parts = self._buffer.split("assistantfinal", 1)
|
||||||
|
reasoning_text = parts[0][len("analysis") :].strip()
|
||||||
|
final_content = parts[1].strip()
|
||||||
|
|
||||||
|
# Clear buffer and return both reasoning and final content
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(
|
||||||
|
reasoning_text=reasoning_text if self.stream_reasoning else "",
|
||||||
|
normal_text=final_content,
|
||||||
|
)
|
||||||
|
elif self.stream_reasoning:
|
||||||
|
# Stream reasoning content incrementally as it arrives
|
||||||
|
current_reasoning = self._buffer[len("analysis") :].strip()
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(reasoning_text=current_reasoning)
|
||||||
|
else:
|
||||||
|
# Wait for assistantfinal
|
||||||
|
return StreamingParseResult()
|
||||||
|
elif self._buffer.startswith("assistantfinal"):
|
||||||
|
# Direct final content without analysis
|
||||||
|
final_content = self._buffer[len("assistantfinal") :].strip()
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(normal_text=final_content)
|
||||||
|
|
||||||
|
# For full channel format, process sections as they complete
|
||||||
|
result = StreamingParseResult()
|
||||||
|
|
||||||
|
# Process complete analysis sections
|
||||||
|
while (
|
||||||
|
self.think_start_token in self._buffer
|
||||||
|
and self.think_end_token in self._buffer
|
||||||
|
):
|
||||||
|
start_idx = self._buffer.find(self.think_start_token)
|
||||||
|
start_pos = start_idx + len(self.think_start_token)
|
||||||
|
end_pos = self._buffer.find(self.think_end_token, start_pos)
|
||||||
|
|
||||||
|
if end_pos != -1:
|
||||||
|
reasoning_content = self._buffer[start_pos:end_pos].strip()
|
||||||
|
if self.stream_reasoning and reasoning_content:
|
||||||
|
result.reasoning_text += reasoning_content
|
||||||
|
|
||||||
|
# Remove processed analysis section
|
||||||
|
self._buffer = (
|
||||||
|
self._buffer[:start_idx]
|
||||||
|
+ self._buffer[end_pos + len(self.think_end_token) :]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process complete commentary sections
|
||||||
|
commentary_pattern = re.compile(
|
||||||
|
r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
for match in reversed(list(commentary_pattern.finditer(self._buffer))):
|
||||||
|
# Check if this is a tool call
|
||||||
|
start_pos = match.start()
|
||||||
|
commentary_content = match.group(1).strip()
|
||||||
|
if self.stream_reasoning and commentary_content:
|
||||||
|
result.reasoning_text += commentary_content
|
||||||
|
|
||||||
|
# Remove this commentary section
|
||||||
|
self._buffer = self._buffer[: match.start()] + self._buffer[match.end() :]
|
||||||
|
# Clean up any standalone <|start|>assistant
|
||||||
|
self._buffer = re.sub(
|
||||||
|
r"<\|start\|>assistant(?=<\|start\|>assistant)", "", self._buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle final channel completion
|
||||||
|
if self.final_channel_start in self._buffer:
|
||||||
|
final_start = self._buffer.find(self.final_channel_start)
|
||||||
|
final_content_start = final_start + len(self.final_channel_start)
|
||||||
|
|
||||||
|
# Check if final channel is complete
|
||||||
|
final_end = self._buffer.find(self.final_channel_end, final_content_start)
|
||||||
|
if final_end != -1:
|
||||||
|
# Complete final channel - process everything
|
||||||
|
final_result = self.detect_and_parse(self._buffer)
|
||||||
|
self._buffer = ""
|
||||||
|
return StreamingParseResult(
|
||||||
|
normal_text=final_result.normal_text,
|
||||||
|
reasoning_text=result.reasoning_text + final_result.reasoning_text,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Extract content before final channel (e.g. tool calls)
|
||||||
|
before_final = self._buffer[:final_start]
|
||||||
|
if before_final:
|
||||||
|
# Output tool calls for processing
|
||||||
|
result.normal_text += before_final
|
||||||
|
# Keep the final channel part in buffer
|
||||||
|
self._buffer = self._buffer[final_start:]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ReasoningParser:
|
class ReasoningParser:
|
||||||
"""
|
"""
|
||||||
Parser that handles both streaming and non-streaming scenarios for extracting
|
Parser that handles both streaming and non-streaming scenarios for extracting
|
||||||
@@ -203,6 +518,7 @@ class ReasoningParser:
|
|||||||
"glm45": Qwen3Detector,
|
"glm45": Qwen3Detector,
|
||||||
"kimi": KimiDetector,
|
"kimi": KimiDetector,
|
||||||
"step3": DeepSeekR1Detector,
|
"step3": DeepSeekR1Detector,
|
||||||
|
"gpt-oss": GptOssDetector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1190,7 +1190,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tool-call-parser",
|
"--tool-call-parser",
|
||||||
type=str,
|
type=str,
|
||||||
choices=[
|
choices=[ # TODO: use FunctionCallParser.DetectorMap.keys()
|
||||||
"qwen25",
|
"qwen25",
|
||||||
"mistral",
|
"mistral",
|
||||||
"llama3",
|
"llama3",
|
||||||
@@ -1200,6 +1200,7 @@ class ServerArgs:
|
|||||||
"qwen3_coder",
|
"qwen3_coder",
|
||||||
"glm45",
|
"glm45",
|
||||||
"step3",
|
"step3",
|
||||||
|
"gpt-oss",
|
||||||
],
|
],
|
||||||
default=ServerArgs.tool_call_parser,
|
default=ServerArgs.tool_call_parser,
|
||||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
|
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
|
||||||
|
|||||||
Reference in New Issue
Block a user