From 7ba3de0e921c31e7dafa2b06228a180c2451c1cd Mon Sep 17 00:00:00 2001 From: Chang Su Date: Tue, 7 Oct 2025 17:36:05 -0700 Subject: [PATCH] [oai serving chat] Add argument `--sampling-defaults` and fix `ChatCompletionRequest` defaults (#11304) --- python/sglang/srt/configs/model_config.py | 37 ++++- .../sglang/srt/entrypoints/openai/protocol.py | 152 ++++++++++++++---- .../srt/entrypoints/openai/serving_chat.py | 84 ++-------- python/sglang/srt/server_args.py | 11 ++ test/srt/openai_server/basic/test_protocol.py | 18 ++- .../openai_server/basic/test_serving_chat.py | 22 --- 6 files changed, 198 insertions(+), 126 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index b89b8de68..ce54ef802 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -17,7 +17,7 @@ import logging import math import os from enum import Enum, IntEnum, auto -from typing import Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union import torch from transformers import PretrainedConfig @@ -90,6 +90,7 @@ class ModelConfig: is_draft_model: bool = False, hybrid_kvcache_ratio: Optional[float] = None, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, + sampling_defaults: str = "openai", ) -> None: # Parse args self.model_path = model_path @@ -98,6 +99,7 @@ class ModelConfig: self.modelopt_quant = modelopt_quant self.is_draft_model = is_draft_model self.model_impl = model_impl + self.sampling_defaults = sampling_defaults # Get hf config self._maybe_pull_model_tokenizer_from_remote() @@ -214,6 +216,7 @@ class ModelConfig: modelopt_quant=server_args.modelopt_quant, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, model_impl=server_args.model_impl, + sampling_defaults=server_args.sampling_defaults, **kwargs, ) @@ -659,6 +662,38 @@ class ModelConfig: eos_ids = eos_ids | generation_eos_ids return eos_ids + def get_default_sampling_params(self) -> dict[str, Any]: + """ + Get default sampling parameters from the model's generation config. + + This method returns non-default sampling parameters from the model's + generation_config.json when sampling_defaults is set to "model". + + Returns: + A dictionary containing the non-default sampling parameters. + """ + if self.sampling_defaults != "model": + return {} + + if self.hf_generation_config is None: + return {} + + config = self.hf_generation_config.to_dict() + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + + default_sampling_params = { + p: config.get(p) for p in available_params if config.get(p) is not None + } + + return default_sampling_params + def _maybe_pull_model_tokenizer_from_remote(self) -> None: """ Pull the model config files to a temporary diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 1ce282927..3acb791aa 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -13,6 +13,7 @@ # ============================================================================== """Pydantic models for OpenAI API protocol""" +import logging import time import uuid from dataclasses import dataclass @@ -37,6 +38,10 @@ from pydantic import ( ) from typing_extensions import Literal +from sglang.utils import convert_json_schema_to_str + +logger = logging.getLogger(__name__) + DEFAULT_MODEL_NAME = "default" @@ -445,8 +450,8 @@ class ChatCompletionRequest(BaseModel): stop: Optional[Union[str, List[str]]] = None stream: bool = False stream_options: Optional[StreamOptions] = None - temperature: float = 0.7 - top_p: float = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None user: Optional[str] = None tools: Optional[List[Tool]] = Field(default=None, examples=[None]) tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( @@ -461,6 +466,47 @@ class ChatCompletionRequest(BaseModel): "Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.", ) + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + top_k: Optional[int] = None + min_p: Optional[float] = None + min_tokens: int = 0 + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: Optional[float] = None + stop_token_ids: Optional[List[int]] = None + no_stop_trim: bool = False + ignore_eos: bool = False + continue_final_message: bool = False + skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + separate_reasoning: bool = True + stream_reasoning: bool = True + chat_template_kwargs: Optional[Dict] = None + + # For request id + rid: Optional[Union[List[str], str]] = None + # Extra key for classifying the request (e.g. cache_salt) + extra_key: Optional[Union[List[str], str]] = None + # Cache salt for request caching + cache_salt: Optional[Union[List[str], str]] = None + # Priority for the request + priority: Optional[int] = None + + # For PD disaggregation + bootstrap_host: Optional[Union[List[str], str]] = None + bootstrap_port: Optional[Union[List[Optional[int]], int]] = None + bootstrap_room: Optional[Union[List[int], int]] = None + + # OpenAI/SGLang default sampling parameters + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + "repetition_penalty": 1.0, + } + @model_validator(mode="before") @classmethod def set_tool_choice_default(cls, values): @@ -531,37 +577,81 @@ class ChatCompletionRequest(BaseModel): return values - # Extra parameters for SRT backend only and will be ignored by OpenAI models. - top_k: int = -1 - min_p: float = 0.0 - min_tokens: int = 0 - regex: Optional[str] = None - ebnf: Optional[str] = None - repetition_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = None - no_stop_trim: bool = False - ignore_eos: bool = False - continue_final_message: bool = False - skip_special_tokens: bool = True - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - session_params: Optional[Dict] = None - separate_reasoning: bool = True - stream_reasoning: bool = True - chat_template_kwargs: Optional[Dict] = None + def to_sampling_params( + self, + stop: List[str], + model_generation_config: Dict[str, Any], + tool_call_constraint: Optional[Any] = None, + ) -> Dict[str, Any]: + """ + Convert request to sampling parameters. + Priority: user value > model generation_config > OpenAI defaults + """ - # For request id - rid: Optional[Union[List[str], str]] = None - # Extra key for classifying the request (e.g. cache_salt) - extra_key: Optional[Union[List[str], str]] = None - # Cache salt for request caching - cache_salt: Optional[Union[List[str], str]] = None - # Priority for the request - priority: Optional[int] = None + def get_param(param_name: str): + value = getattr(self, param_name) + if value is None: + return model_generation_config.get( + param_name, self._DEFAULT_SAMPLING_PARAMS[param_name] + ) + return value - # For PD disaggregation - bootstrap_host: Optional[Union[List[str], str]] = None - bootstrap_port: Optional[Union[List[Optional[int]], int]] = None - bootstrap_room: Optional[Union[List[int], int]] = None + sampling_params = { + "temperature": get_param("temperature"), + "max_new_tokens": self.max_tokens or self.max_completion_tokens, + "min_new_tokens": self.min_tokens, + "stop": stop, + "stop_token_ids": self.stop_token_ids, + "top_p": get_param("top_p"), + "top_k": get_param("top_k"), + "min_p": get_param("min_p"), + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + "repetition_penalty": get_param("repetition_penalty"), + "regex": self.regex, + "ebnf": self.ebnf, + "n": self.n, + "no_stop_trim": self.no_stop_trim, + "ignore_eos": self.ignore_eos, + "skip_special_tokens": self.skip_special_tokens, + "logit_bias": self.logit_bias, + } + + if self.response_format and self.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + self.response_format.json_schema.schema_ + ) + elif self.response_format and self.response_format.type == "json_object": + sampling_params["json_schema"] = '{"type": "object"}' + elif self.response_format and self.response_format.type == "structural_tag": + sampling_params["structural_tag"] = convert_json_schema_to_str( + self.response_format.model_dump(by_alias=True) + ) + + # Check if there are already existing output constraints + has_existing_constraints = ( + sampling_params.get("regex") + or sampling_params.get("ebnf") + or sampling_params.get("structural_tag") + or sampling_params.get("json_schema") + ) + + if tool_call_constraint and has_existing_constraints: + logger.warning("Constrained decoding is not compatible with tool calls.") + elif tool_call_constraint: + constraint_type, constraint_value = tool_call_constraint + if constraint_type == "structural_tag": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value.model_dump(by_alias=True) + ) + elif constraint_type == "json_schema": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value + ) + else: + sampling_params[constraint_type] = constraint_value + + return sampling_params class ChatMessage(BaseModel): diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 13e40a19c..08a2bf20d 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -44,7 +44,6 @@ from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.parser.conversation import generate_chat_conv from sglang.srt.parser.jinja_template_utils import process_content_for_template_format from sglang.srt.parser.reasoning_parser import ReasoningParser -from sglang.utils import convert_json_schema_to_str if TYPE_CHECKING: from sglang.srt.managers.template_manager import TemplateManager @@ -66,6 +65,15 @@ class OpenAIServingChat(OpenAIServingBase): self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser + # Get default sampling parameters from model's generation config + self.default_sampling_params = ( + self.tokenizer_manager.model_config.get_default_sampling_params() + ) + if self.default_sampling_params: + logger.info( + f"Using default chat sampling params from model generation config: {self.default_sampling_params}", + ) + def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -137,10 +145,10 @@ class OpenAIServingChat(OpenAIServingBase): processed_messages = self._process_messages(request, is_multimodal) # Build sampling parameters - sampling_params = self._build_sampling_params( - request, - processed_messages.stop, - processed_messages.tool_call_constraint, + sampling_params = request.to_sampling_params( + stop=processed_messages.stop, + model_generation_config=self.default_sampling_params, + tool_call_constraint=processed_messages.tool_call_constraint, ) # Handle single vs multiple requests @@ -410,72 +418,6 @@ class OpenAIServingChat(OpenAIServingBase): stop=stop, ) - def _build_sampling_params( - self, - request: ChatCompletionRequest, - stop: List[str], - tool_call_constraint: Optional[Any], - ) -> Dict[str, Any]: - """Build sampling parameters for the request""" - - sampling_params = { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens or request.max_completion_tokens, - "min_new_tokens": request.min_tokens, - "stop": stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "top_k": request.top_k, - "min_p": request.min_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "ebnf": request.ebnf, - "n": request.n, - "no_stop_trim": request.no_stop_trim, - "ignore_eos": request.ignore_eos, - "skip_special_tokens": request.skip_special_tokens, - "logit_bias": request.logit_bias, - } - - if request.response_format and request.response_format.type == "json_schema": - sampling_params["json_schema"] = convert_json_schema_to_str( - request.response_format.json_schema.schema_ - ) - elif request.response_format and request.response_format.type == "json_object": - sampling_params["json_schema"] = '{"type": "object"}' - elif ( - request.response_format and request.response_format.type == "structural_tag" - ): - sampling_params["structural_tag"] = convert_json_schema_to_str( - request.response_format.model_dump(by_alias=True) - ) - - # Check if there are already existing output constraints - has_existing_constraints = ( - sampling_params.get("regex") - or sampling_params.get("ebnf") - or sampling_params.get("structural_tag") - or sampling_params.get("json_schema") - ) - - if tool_call_constraint and has_existing_constraints: - logger.warning("Constrained decoding is not compatible with tool calls.") - elif tool_call_constraint: - constraint_type, constraint_value = tool_call_constraint - if constraint_type == "structural_tag": - sampling_params[constraint_type] = convert_json_schema_to_str( - constraint_value.model_dump(by_alias=True) - ) - elif constraint_type == "json_schema": - sampling_params[constraint_type] = convert_json_schema_to_str( - constraint_value - ) - else: - sampling_params[constraint_type] = constraint_value - return sampling_params - async def _handle_streaming_request( self, adapted_request: GenerateReqInput, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 401875ff7..8c955aabc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -252,6 +252,7 @@ class ServerArgs: reasoning_parser: Optional[str] = None tool_call_parser: Optional[str] = None tool_server: Optional[str] = None + sampling_defaults: str = "model" # Data parallelism dp_size: int = 1 @@ -1872,6 +1873,16 @@ class ServerArgs: default=ServerArgs.tool_call_parser, help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.", ) + parser.add_argument( + "--sampling-defaults", + type=str, + choices=["openai", "model"], + default=ServerArgs.sampling_defaults, + help="Where to get default sampling parameters. " + "'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). " + "'model' uses the model's generation_config.json to get the recommended " + "sampling parameters if available. Default is 'model'.", + ) parser.add_argument( "--tool-server", type=str, diff --git a/test/srt/openai_server/basic/test_protocol.py b/test/srt/openai_server/basic/test_protocol.py index fcaa9770b..fbf1e3971 100644 --- a/test/srt/openai_server/basic/test_protocol.py +++ b/test/srt/openai_server/basic/test_protocol.py @@ -150,10 +150,26 @@ class TestChatCompletionRequest(unittest.TestCase): self.assertEqual(len(request.messages), 1) self.assertEqual(request.messages[0].role, "user") self.assertEqual(request.messages[0].content, "Hello") - self.assertEqual(request.temperature, 0.7) # default + self.assertEqual(request.temperature, None) # default self.assertFalse(request.stream) # default self.assertEqual(request.tool_choice, "none") # default when no tools + def test_sampling_param_build(self): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi"}], + temperature=0.8, + max_tokens=150, + min_tokens=5, + top_p=0.9, + stop=[""], + ) + params = req.to_sampling_params([""], {}, None) + self.assertEqual(params["temperature"], 0.8) + self.assertEqual(params["max_new_tokens"], 150) + self.assertEqual(params["min_new_tokens"], 5) + self.assertEqual(params["stop"], [""]) + def test_chat_completion_tool_choice_validation(self): """Test tool choice validation logic""" messages = [{"role": "user", "content": "Hello"}] diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py index a4e55385f..fbbbcccdd 100644 --- a/test/srt/openai_server/basic/test_serving_chat.py +++ b/test/srt/openai_server/basic/test_serving_chat.py @@ -177,28 +177,6 @@ class ServingChatTestCase(unittest.TestCase): self.assertNotIn("CUSTOM_STOP", result2.stop) self.assertEqual(conv_ins.stop_str, initial_stop_str) - # ------------- sampling-params ------------- - def test_sampling_param_build(self): - req = ChatCompletionRequest( - model="x", - messages=[{"role": "user", "content": "Hi"}], - temperature=0.8, - max_tokens=150, - min_tokens=5, - top_p=0.9, - stop=[""], - ) - with patch.object( - self.chat, - "_process_messages", - return_value=("Prompt", [1], None, None, [], [""], None), - ): - params = self.chat._build_sampling_params(req, [""], None) - self.assertEqual(params["temperature"], 0.8) - self.assertEqual(params["max_new_tokens"], 150) - self.assertEqual(params["min_new_tokens"], 5) - self.assertEqual(params["stop"], [""]) - async def test_unstreamed_tool_args_completion(self): """Test that remaining tool call arguments are sent when generation finishes."""