misc: Improvement to serving_chat.py and add more unit tests (#7489)

This commit is contained in:
Chang Su
2025-06-24 17:19:51 -07:00
committed by GitHub
parent 3562256bb2
commit 112b496a6c
3 changed files with 139 additions and 80 deletions

View File

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
FunctionResponse,
LogProbs,
MessageProcessingResult,
ToolCall,
TopLogprob,
)
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
# Process messages and apply chat template
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = self._process_messages(request, is_multimodal)
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, stop, tool_call_constraint
request, processed_messages.stop, processed_messages.tool_call_constraint
)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(prompt_ids, str):
prompt_kwargs = {"text": prompt_ids}
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": prompt_ids}
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data,
audio_data=audio_data,
image_data=processed_messages.image_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=modalities,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
def _process_messages(
self, request: ChatCompletionRequest, is_multimodal: bool
) -> tuple[
str,
Union[str, List[int]],
Optional[Any],
Optional[Any],
List[str],
List[str],
Optional[Any],
]:
) -> MessageProcessingResult:
"""Process chat messages and apply chat template"""
tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
# Apply chat template and its stop strings
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
# Use chat template
if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
# Apply chat template and its stop strings
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request, is_multimodal)
)
else:
# Use raw prompt
prompt_ids = request.messages
stop = request.stop or []
image_data = None
audio_data = None
modalities = []
prompt = request.messages
tools = [item.function.model_dump() for item in request.tools]
return (
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
)
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
# Use chat template
if self.template_manager.chat_template_name is None:
result = self._apply_jinja_template(request, tools, is_multimodal)
else:
result = self._apply_conversation_template(request, is_multimodal)
result.tool_call_constraint = tool_call_constraint
return result
def _apply_jinja_template(
self,
request: ChatCompletionRequest,
tools: Optional[List[Dict]],
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
) -> MessageProcessingResult:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
)
def _apply_conversation_template(
self,
request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
) -> MessageProcessingResult:
"""Apply conversation template"""
prompt = ""
prompt_ids = []
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
return prompt, prompt_ids, image_data, audio_data, modalities, stop
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
)
def _build_sampling_params(
self,