misc: Improve serving_chat.py and add more unit tests (#7489)
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
"""Pydantic models for OpenAI API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -587,3 +588,30 @@ OpenAIServingRequest = Union[
|
||||
ScoringRequest,
|
||||
V1RerankReqInput,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MessageProcessingResult:
    """Result of processing chat messages and applying templates.

    This dataclass encapsulates all the outputs from message processing including
    prompt generation, multimodal data extraction, and constraint preparation.
    Used internally by OpenAIServingChat to pass processed data between methods.

    The class is intentionally left mutable: OpenAIServingChat._process_messages
    receives an instance from one of the template helpers and assigns
    ``tool_call_constraint`` on it afterwards.

    Attributes:
        prompt: The final text prompt after applying chat template
        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
        image_data: Extracted image data from messages, if any
        audio_data: Extracted audio data from messages, if any
        modalities: List of modality types present in the messages
        stop: Combined stop strings from template and request
        tool_call_constraint: Optional constraint for structured tool calls;
            defaults to None when no constraint applies
    """

    prompt: str
    prompt_ids: Union[str, List[int]]
    image_data: Optional[Any]
    audio_data: Optional[Any]
    modalities: List[str]
    stop: List[str]
    # Only field with a default, so the template helpers can omit it.
    tool_call_constraint: Optional[Any] = None
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
FunctionResponse,
|
||||
LogProbs,
|
||||
MessageProcessingResult,
|
||||
ToolCall,
|
||||
TopLogprob,
|
||||
)
|
||||
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
||||
|
||||
# Process messages and apply chat template
|
||||
(
|
||||
prompt,
|
||||
prompt_ids,
|
||||
image_data,
|
||||
audio_data,
|
||||
modalities,
|
||||
stop,
|
||||
tool_call_constraint,
|
||||
) = self._process_messages(request, is_multimodal)
|
||||
processed_messages = self._process_messages(request, is_multimodal)
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(
|
||||
request, stop, tool_call_constraint
|
||||
request, processed_messages.stop, processed_messages.tool_call_constraint
|
||||
)
|
||||
|
||||
# Pass the text prompt for multimodal models or string prompts; otherwise pass token IDs
|
||||
if is_multimodal:
|
||||
prompt_kwargs = {"text": prompt}
|
||||
prompt_kwargs = {"text": processed_messages.prompt}
|
||||
else:
|
||||
if isinstance(prompt_ids, str):
|
||||
prompt_kwargs = {"text": prompt_ids}
|
||||
if isinstance(processed_messages.prompt_ids, str):
|
||||
prompt_kwargs = {"text": processed_messages.prompt_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt_ids}
|
||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
image_data=processed_messages.image_data,
|
||||
audio_data=processed_messages.audio_data,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=request.logprobs,
|
||||
logprob_start_len=-1,
|
||||
top_logprobs_num=request.top_logprobs or 0,
|
||||
stream=request.stream,
|
||||
return_text_in_logprobs=True,
|
||||
modalities=modalities,
|
||||
modalities=processed_messages.modalities,
|
||||
lora_path=request.lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
def _process_messages(
|
||||
self, request: ChatCompletionRequest, is_multimodal: bool
|
||||
) -> tuple[
|
||||
str,
|
||||
Union[str, List[int]],
|
||||
Optional[Any],
|
||||
Optional[Any],
|
||||
List[str],
|
||||
List[str],
|
||||
Optional[Any],
|
||||
]:
|
||||
) -> MessageProcessingResult:
|
||||
"""Process chat messages and apply chat template"""
|
||||
tool_call_constraint = None
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
for item in request.tools
|
||||
if item.function.name == request.tool_choice.function.name
|
||||
]
|
||||
else:
|
||||
tools = [item.function.model_dump() for item in request.tools]
|
||||
|
||||
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||
tool_call_constraint = parser.get_structure_constraint(
|
||||
request.tool_choice
|
||||
)
|
||||
|
||||
# Use chat template
|
||||
if self.template_manager.chat_template_name is None:
|
||||
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
|
||||
self._apply_jinja_template(request, tools, is_multimodal)
|
||||
)
|
||||
# Apply chat template and its stop strings
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
for item in request.tools
|
||||
if item.function.name == request.tool_choice.function.name
|
||||
]
|
||||
else:
|
||||
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
|
||||
self._apply_conversation_template(request, is_multimodal)
|
||||
)
|
||||
else:
|
||||
# Use raw prompt
|
||||
prompt_ids = request.messages
|
||||
stop = request.stop or []
|
||||
image_data = None
|
||||
audio_data = None
|
||||
modalities = []
|
||||
prompt = request.messages
|
||||
tools = [item.function.model_dump() for item in request.tools]
|
||||
|
||||
return (
|
||||
prompt,
|
||||
prompt_ids,
|
||||
image_data,
|
||||
audio_data,
|
||||
modalities,
|
||||
stop,
|
||||
tool_call_constraint,
|
||||
)
|
||||
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||
tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
|
||||
|
||||
# Use chat template
|
||||
if self.template_manager.chat_template_name is None:
|
||||
result = self._apply_jinja_template(request, tools, is_multimodal)
|
||||
else:
|
||||
result = self._apply_conversation_template(request, is_multimodal)
|
||||
|
||||
result.tool_call_constraint = tool_call_constraint
|
||||
return result
|
||||
|
||||
def _apply_jinja_template(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
tools: Optional[List[Dict]],
|
||||
is_multimodal: bool,
|
||||
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
|
||||
) -> MessageProcessingResult:
|
||||
"""Apply Jinja chat template"""
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
image_data = image_data if image_data else None
|
||||
audio_data = audio_data if audio_data else None
|
||||
modalities = modalities if modalities else []
|
||||
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||
return MessageProcessingResult(
|
||||
prompt=prompt,
|
||||
prompt_ids=prompt_ids,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
modalities=modalities,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _apply_conversation_template(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
is_multimodal: bool,
|
||||
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
|
||||
) -> MessageProcessingResult:
|
||||
"""Apply conversation template"""
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if not is_multimodal:
|
||||
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
|
||||
|
||||
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||
return MessageProcessingResult(
|
||||
prompt=prompt,
|
||||
prompt_ids=prompt_ids,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
modalities=modalities,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _build_sampling_params(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,10 @@ from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
MessageProcessingResult,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
@@ -104,7 +107,7 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
conv_ins.stop_str = ["</s>"]
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
proc_mock.return_value = (
|
||||
proc_mock.return_value = MessageProcessingResult(
|
||||
"Test prompt",
|
||||
[1, 2, 3],
|
||||
None,
|
||||
@@ -119,6 +122,59 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
def test_stop_str_isolation_between_requests(self):
    """Stop strings supplied by one request must not leak into the next one.

    Regression test for the bug where conv.stop_str was being mutated globally,
    causing stop strings from one request to persist in subsequent requests.
    """
    template_stop = ["\n"]

    # Stand-in for the conversation object returned by generate_chat_conv;
    # its stop_str is a copy of the template's default stop strings.
    fake_conv = Mock()
    fake_conv.get_prompt.return_value = "Test prompt"
    fake_conv.image_data = None
    fake_conv.audio_data = None
    fake_conv.modalities = []
    fake_conv.stop_str = list(template_stop)

    with patch(
        "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv",
        return_value=fake_conv,
    ):
        # First request carries an extra, request-specific stop string.
        first_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "First request"}],
            stop=["CUSTOM_STOP"],
        )

        # Exercise the real _apply_conversation_template (not mocked).
        first = self.chat._apply_conversation_template(
            first_req, is_multimodal=False
        )

        # The result combines template and request stop strings ...
        self.assertEqual(first.stop, template_stop + ["CUSTOM_STOP"])
        # ... while the template's own stop_str stays untouched.
        self.assertEqual(fake_conv.stop_str, template_stop)

        # Second request supplies no custom stop strings.
        second_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Second request"}],
        )
        second = self.chat._apply_conversation_template(
            second_req, is_multimodal=False
        )

        # Only the template's stop strings remain; nothing leaked from the first request.
        self.assertEqual(second.stop, template_stop)
        self.assertNotIn("CUSTOM_STOP", second.stop)
        self.assertEqual(fake_conv.stop_str, template_stop)
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
Reference in New Issue
Block a user