misc: Improvement to serving_chat.py and add more ut (#7489)

sglang/srt/entrypoints/openai/protocol.py

@@ -14,7 +14,8 @@
 """Pydantic models for OpenAI API protocol"""
 
 import time
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import (
     BaseModel,
@@ -587,3 +588,30 @@ OpenAIServingRequest = Union[
     ScoringRequest,
     V1RerankReqInput,
 ]
+
+
+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all the outputs from message processing including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Args:
+        prompt: The final text prompt after applying chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from template and request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None
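
Not part of the commit itself: a minimal sketch of how the new dataclass is constructed and consumed. The field names come from the definition above; the concrete values, and the constraint value on the last line, are made up for illustration.

```python
from sglang.srt.entrypoints.openai.protocol import MessageProcessingResult

# Template helpers can build the result without a constraint, because
# tool_call_constraint defaults to None; _process_messages attaches it later.
result = MessageProcessingResult(
    prompt="<prompt text after chat template>",
    prompt_ids=[1, 2, 3],
    image_data=None,
    audio_data=None,
    modalities=[],
    stop=["</s>"],
)
assert result.tool_call_constraint is None
result.tool_call_constraint = {"kind": "illustrative-only"}  # hypothetical value
```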

sglang/srt/entrypoints/openai/serving_chat.py

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal
 
         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)
 
         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )
 
         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
 
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
 
     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []
 
-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]
 
-                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-                parser = FunctionCallParser(request.tools, tool_call_parser)
-                tool_call_constraint = parser.get_structure_constraint(
-                    request.tool_choice
-                )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
 
-            # Use chat template
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
-        else:
-            # Use raw prompt
-            prompt_ids = request.messages
-            stop = request.stop or []
-            image_data = None
-            audio_data = None
-            modalities = []
-            prompt = request.messages
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
+        else:
+            result = self._apply_conversation_template(request, is_multimodal)
 
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+        result.tool_call_constraint = tool_call_constraint
+        return result
 
     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )
 
     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
 
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )
 
     def _build_sampling_params(
         self,
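
Why the refactor helps, as a standalone sketch rather than code from the commit: with the old 7-tuple, every caller had to unpack all values positionally, so adding or reordering a field could silently swap values at call sites. With a mutable dataclass, the template helpers build the result first and `_process_messages` attaches the tool-call constraint afterwards. The names below (`Result`, `apply_template`, `process_messages`) are stand-ins, not sglang APIs.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Union

@dataclass
class Result:  # trimmed stand-in for MessageProcessingResult
    prompt: str
    prompt_ids: Union[str, List[int]]
    stop: List[str]
    tool_call_constraint: Optional[Any] = None

def apply_template() -> Result:
    # Stand-in for _apply_jinja_template / _apply_conversation_template:
    # both return a result that does not yet carry a constraint.
    return Result(prompt="hi", prompt_ids=[1, 2], stop=["</s>"])

def process_messages(constraint: Any) -> Result:
    result = apply_template()
    result.tool_call_constraint = constraint  # attached post-hoc, as in the diff
    return result

processed = process_messages({"kind": "demo"})
print(processed.stop, processed.tool_call_constraint)  # ['</s>'] {'kind': 'demo'}
```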

unit tests for OpenAIServingChat (file path not shown in this excerpt)

@@ -13,7 +13,10 @@ from unittest.mock import Mock, patch
 
 from fastapi import Request
 
-from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    MessageProcessingResult,
+)
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
 from sglang.srt.managers.io_struct import GenerateReqInput
 
@@ -104,7 +107,7 @@ class ServingChatTestCase(unittest.TestCase):
             conv_ins.stop_str = ["</s>"]
             conv_mock.return_value = conv_ins
 
-            proc_mock.return_value = (
+            proc_mock.return_value = MessageProcessingResult(
                 "Test prompt",
                 [1, 2, 3],
                 None,
@@ -119,6 +122,59 @@ class ServingChatTestCase(unittest.TestCase):
         self.assertFalse(adapted.stream)
         self.assertEqual(processed, self.basic_req)
 
+    def test_stop_str_isolation_between_requests(self):
+        """Test that stop strings from one request don't affect subsequent requests.
+
+        This tests the fix for the bug where conv.stop_str was being mutated globally,
+        causing stop strings from one request to persist in subsequent requests.
+        """
+        # Mock conversation template with initial stop_str
+        initial_stop_str = ["\n"]
+
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
+        ) as conv_mock:
+            # Create a mock conversation object that will be returned by generate_chat_conv
+            conv_ins = Mock()
+            conv_ins.get_prompt.return_value = "Test prompt"
+            conv_ins.image_data = None
+            conv_ins.audio_data = None
+            conv_ins.modalities = []
+            conv_ins.stop_str = (
+                initial_stop_str.copy()
+            )  # Template's default stop strings
+            conv_mock.return_value = conv_ins
+
+            # First request with additional stop string
+            req1 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "First request"}],
+                stop=["CUSTOM_STOP"],
+            )
+
+            # Call the actual _apply_conversation_template method (not mocked)
+            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
+
+            # Verify first request has both stop strings
+            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
+            self.assertEqual(result1.stop, expected_stop1)
+
+            # Verify the original template's stop_str wasn't mutated after first request
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
+            # Second request without additional stop string
+            req2 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "Second request"}],
+                # No custom stop strings
+            )
+            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
+
+            # Verify second request only has original stop strings (no CUSTOM_STOP from req1)
+            self.assertEqual(result2.stop, initial_stop_str)
+            self.assertNotIn("CUSTOM_STOP", result2.stop)
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
     # ------------- sampling-params -------------
     def test_sampling_param_build(self):
         req = ChatCompletionRequest(
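
The pattern the new test pins down, as a self-contained sketch (the actual fix lives in `_apply_conversation_template`, which this excerpt shows only in part; `merge_stop_strings` is a hypothetical helper): merge the request's stop strings into a copy of the template's list instead of extending the shared `conv.stop_str` in place.

```python
from typing import List, Optional, Union

def merge_stop_strings(
    template_stop: Optional[List[str]], request_stop: Union[str, List[str], None]
) -> List[str]:
    # Copy first: the template's stop_str is shared state, so extending it
    # in place would leak one request's stop strings into later requests.
    stop = list(template_stop or [])
    if request_stop:
        if isinstance(request_stop, str):
            stop.append(request_stop)
        else:
            stop.extend(request_stop)
    return stop

template = ["\n"]
assert merge_stop_strings(template, ["CUSTOM_STOP"]) == ["\n", "CUSTOM_STOP"]
assert merge_stop_strings(template, None) == ["\n"]
assert template == ["\n"]  # the shared template list is never mutated
```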