misc: Improve serving_chat.py and add more unit tests (#7489)

This commit is contained in:
Chang Su
2025-06-24 17:19:51 -07:00
committed by GitHub
parent 3562256bb2
commit 112b496a6c
3 changed files with 139 additions and 80 deletions

View File

@@ -14,7 +14,8 @@
"""Pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from pydantic import (
BaseModel,
@@ -587,3 +588,30 @@ OpenAIServingRequest = Union[
ScoringRequest,
V1RerankReqInput,
]
@dataclass
class MessageProcessingResult:
    """Result of processing chat messages and applying templates.

    This dataclass encapsulates all the outputs from message processing including
    prompt generation, multimodal data extraction, and constraint preparation.
    Used internally by OpenAIServingChat to pass processed data between methods.

    Attributes:
        prompt: The final text prompt after applying the chat template.
        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int]).
        image_data: Extracted image data from messages, if any.
        audio_data: Extracted audio data from messages, if any.
        modalities: List of modality types present in the messages.
        stop: Combined stop strings from the template and the request.
        tool_call_constraint: Optional constraint for structured tool calls.
    """

    prompt: str
    prompt_ids: Union[str, List[int]]
    image_data: Optional[Any]
    audio_data: Optional[Any]
    modalities: List[str]
    stop: List[str]
    # Defaults to None so the template-application helpers can build the result
    # without a constraint; _process_messages assigns it afterwards when tools
    # are in play.
    tool_call_constraint: Optional[Any] = None

View File

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
FunctionResponse,
LogProbs,
MessageProcessingResult,
ToolCall,
TopLogprob,
)
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
# Process messages and apply chat template
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = self._process_messages(request, is_multimodal)
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, stop, tool_call_constraint
request, processed_messages.stop, processed_messages.tool_call_constraint
)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(prompt_ids, str):
prompt_kwargs = {"text": prompt_ids}
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": prompt_ids}
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data,
audio_data=audio_data,
image_data=processed_messages.image_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=modalities,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
def _process_messages(
self, request: ChatCompletionRequest, is_multimodal: bool
) -> tuple[
str,
Union[str, List[int]],
Optional[Any],
Optional[Any],
List[str],
List[str],
Optional[Any],
]:
) -> MessageProcessingResult:
"""Process chat messages and apply chat template"""
tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
# Apply chat template and its stop strings
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
# Use chat template
if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
# Apply chat template and its stop strings
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request, is_multimodal)
)
else:
# Use raw prompt
prompt_ids = request.messages
stop = request.stop or []
image_data = None
audio_data = None
modalities = []
prompt = request.messages
tools = [item.function.model_dump() for item in request.tools]
return (
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
)
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
# Use chat template
if self.template_manager.chat_template_name is None:
result = self._apply_jinja_template(request, tools, is_multimodal)
else:
result = self._apply_conversation_template(request, is_multimodal)
result.tool_call_constraint = tool_call_constraint
return result
def _apply_jinja_template(
self,
request: ChatCompletionRequest,
tools: Optional[List[Dict]],
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
) -> MessageProcessingResult:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
)
def _apply_conversation_template(
self,
request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
) -> MessageProcessingResult:
"""Apply conversation template"""
prompt = ""
prompt_ids = []
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
return prompt, prompt_ids, image_data, audio_data, modalities, stop
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
)
def _build_sampling_params(
self,