misc: Improvement to serving_chat.py and add more ut (#7489)
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
"""Pydantic models for OpenAI API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -587,3 +588,30 @@ OpenAIServingRequest = Union[
|
||||
ScoringRequest,
|
||||
V1RerankReqInput,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MessageProcessingResult:
    """Result of processing chat messages and applying templates.

    Bundles everything produced while turning an incoming chat request into
    generation inputs — the rendered prompt, any extracted multimodal
    payloads, stop strings, and the optional tool-call constraint — so that
    OpenAIServingChat can hand a single object between its internal
    processing methods instead of a wide tuple.

    Attributes:
        prompt: Final text prompt after the chat template has been applied.
        prompt_ids: The text prompt (str) or its tokenized form (List[int]).
        image_data: Image payloads extracted from the messages, if any.
        audio_data: Audio payloads extracted from the messages, if any.
        modalities: Names of the modality types present in the messages.
        stop: Stop strings combined from the template and the request.
        tool_call_constraint: Optional constraint for structured tool calls.
    """

    prompt: str
    prompt_ids: Union[str, List[int]]
    image_data: Optional[Any]
    audio_data: Optional[Any]
    modalities: List[str]
    stop: List[str]
    tool_call_constraint: Optional[Any] = None
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
FunctionResponse,
|
||||
LogProbs,
|
||||
MessageProcessingResult,
|
||||
ToolCall,
|
||||
TopLogprob,
|
||||
)
|
||||
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
||||
|
||||
# Process messages and apply chat template
|
||||
(
|
||||
prompt,
|
||||
prompt_ids,
|
||||
image_data,
|
||||
audio_data,
|
||||
modalities,
|
||||
stop,
|
||||
tool_call_constraint,
|
||||
) = self._process_messages(request, is_multimodal)
|
||||
processed_messages = self._process_messages(request, is_multimodal)
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(
|
||||
request, stop, tool_call_constraint
|
||||
request, processed_messages.stop, processed_messages.tool_call_constraint
|
||||
)
|
||||
|
||||
# Handle single vs multiple requests
|
||||
if is_multimodal:
|
||||
prompt_kwargs = {"text": prompt}
|
||||
prompt_kwargs = {"text": processed_messages.prompt}
|
||||
else:
|
||||
if isinstance(prompt_ids, str):
|
||||
prompt_kwargs = {"text": prompt_ids}
|
||||
if isinstance(processed_messages.prompt_ids, str):
|
||||
prompt_kwargs = {"text": processed_messages.prompt_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt_ids}
|
||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
image_data=processed_messages.image_data,
|
||||
audio_data=processed_messages.audio_data,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=request.logprobs,
|
||||
logprob_start_len=-1,
|
||||
top_logprobs_num=request.top_logprobs or 0,
|
||||
stream=request.stream,
|
||||
return_text_in_logprobs=True,
|
||||
modalities=modalities,
|
||||
modalities=processed_messages.modalities,
|
||||
lora_path=request.lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
def _process_messages(
    self, request: ChatCompletionRequest, is_multimodal: bool
) -> MessageProcessingResult:
    """Process chat messages and apply the chat template.

    Builds the prompt (text or token ids), extracts multimodal data, and
    prepares the optional structured tool-call constraint, returning all
    outputs in a single MessageProcessingResult rather than a wide tuple.

    Args:
        request: The incoming chat completion request.
        is_multimodal: Whether the underlying model accepts multimodal input.

    Returns:
        MessageProcessingResult populated by the selected template helper,
        with ``tool_call_constraint`` filled in afterwards.
    """
    tool_call_constraint = None

    # Apply chat template and its stop strings
    tools = None
    if request.tools and request.tool_choice != "none":
        # Keep special tokens so tool-call markup survives detokenization
        # for the function-call parser.
        request.skip_special_tokens = False
        if not isinstance(request.tool_choice, str):
            # A specific function was requested: expose only that tool
            # to the template.
            tools = [
                item.function.model_dump()
                for item in request.tools
                if item.function.name == request.tool_choice.function.name
            ]
        else:
            tools = [item.function.model_dump() for item in request.tools]

        tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
        parser = FunctionCallParser(request.tools, tool_call_parser)
        tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

    # Use the Jinja template when no named chat template is configured;
    # otherwise fall back to the conversation template.
    if self.template_manager.chat_template_name is None:
        result = self._apply_jinja_template(request, tools, is_multimodal)
    else:
        result = self._apply_conversation_template(request, is_multimodal)

    # The template helpers do not know about tool constraints; attach it here.
    result.tool_call_constraint = tool_call_constraint
    return result
|
||||
|
||||
def _apply_jinja_template(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
tools: Optional[List[Dict]],
|
||||
is_multimodal: bool,
|
||||
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
|
||||
) -> MessageProcessingResult:
|
||||
"""Apply Jinja chat template"""
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
image_data = image_data if image_data else None
|
||||
audio_data = audio_data if audio_data else None
|
||||
modalities = modalities if modalities else []
|
||||
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||
return MessageProcessingResult(
|
||||
prompt=prompt,
|
||||
prompt_ids=prompt_ids,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
modalities=modalities,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _apply_conversation_template(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
is_multimodal: bool,
|
||||
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
|
||||
) -> MessageProcessingResult:
|
||||
"""Apply conversation template"""
|
||||
prompt = ""
|
||||
prompt_ids = []
|
||||
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if not is_multimodal:
|
||||
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
|
||||
|
||||
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||
return MessageProcessingResult(
|
||||
prompt=prompt,
|
||||
prompt_ids=prompt_ids,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
modalities=modalities,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _build_sampling_params(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user