[oai serving chat] Add argument --sampling-defaults and fix ChatCompletionRequest defaults (#11304)
This commit is contained in:
@@ -17,7 +17,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -90,6 +90,7 @@ class ModelConfig:
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
sampling_defaults: str = "openai",
|
||||
) -> None:
|
||||
# Parse args
|
||||
self.model_path = model_path
|
||||
@@ -98,6 +99,7 @@ class ModelConfig:
|
||||
self.modelopt_quant = modelopt_quant
|
||||
self.is_draft_model = is_draft_model
|
||||
self.model_impl = model_impl
|
||||
self.sampling_defaults = sampling_defaults
|
||||
|
||||
# Get hf config
|
||||
self._maybe_pull_model_tokenizer_from_remote()
|
||||
@@ -214,6 +216,7 @@ class ModelConfig:
|
||||
modelopt_quant=server_args.modelopt_quant,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
sampling_defaults=server_args.sampling_defaults,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -659,6 +662,38 @@ class ModelConfig:
|
||||
eos_ids = eos_ids | generation_eos_ids
|
||||
return eos_ids
|
||||
|
||||
def get_default_sampling_params(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get default sampling parameters from the model's generation config.
|
||||
|
||||
This method returns non-default sampling parameters from the model's
|
||||
generation_config.json when sampling_defaults is set to "model".
|
||||
|
||||
Returns:
|
||||
A dictionary containing the non-default sampling parameters.
|
||||
"""
|
||||
if self.sampling_defaults != "model":
|
||||
return {}
|
||||
|
||||
if self.hf_generation_config is None:
|
||||
return {}
|
||||
|
||||
config = self.hf_generation_config.to_dict()
|
||||
|
||||
available_params = [
|
||||
"repetition_penalty",
|
||||
"temperature",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"min_p",
|
||||
]
|
||||
|
||||
default_sampling_params = {
|
||||
p: config.get(p) for p in available_params if config.get(p) is not None
|
||||
}
|
||||
|
||||
return default_sampling_params
|
||||
|
||||
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
"""
|
||||
Pull the model config files to a temporary
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""Pydantic models for OpenAI API protocol"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
@@ -37,6 +38,10 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL_NAME = "default"
|
||||
|
||||
|
||||
@@ -445,8 +450,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
temperature: float = 0.7
|
||||
top_p: float = 1.0
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
user: Optional[str] = None
|
||||
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
|
||||
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
||||
@@ -461,6 +466,47 @@ class ChatCompletionRequest(BaseModel):
|
||||
"Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.",
|
||||
)
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: Optional[int] = None
|
||||
min_p: Optional[float] = None
|
||||
min_tokens: int = 0
|
||||
regex: Optional[str] = None
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
separate_reasoning: bool = True
|
||||
stream_reasoning: bool = True
|
||||
chat_template_kwargs: Optional[Dict] = None
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
# OpenAI/SGLang default sampling parameters
|
||||
_DEFAULT_SAMPLING_PARAMS = {
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": -1,
|
||||
"min_p": 0.0,
|
||||
"repetition_penalty": 1.0,
|
||||
}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_tool_choice_default(cls, values):
|
||||
@@ -531,37 +577,81 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
return values
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
min_tokens: int = 0
|
||||
regex: Optional[str] = None
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
separate_reasoning: bool = True
|
||||
stream_reasoning: bool = True
|
||||
chat_template_kwargs: Optional[Dict] = None
|
||||
def to_sampling_params(
|
||||
self,
|
||||
stop: List[str],
|
||||
model_generation_config: Dict[str, Any],
|
||||
tool_call_constraint: Optional[Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert request to sampling parameters.
|
||||
Priority: user value > model generation_config > OpenAI defaults
|
||||
"""
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
def get_param(param_name: str):
|
||||
value = getattr(self, param_name)
|
||||
if value is None:
|
||||
return model_generation_config.get(
|
||||
param_name, self._DEFAULT_SAMPLING_PARAMS[param_name]
|
||||
)
|
||||
return value
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
sampling_params = {
|
||||
"temperature": get_param("temperature"),
|
||||
"max_new_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"min_new_tokens": self.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"top_p": get_param("top_p"),
|
||||
"top_k": get_param("top_k"),
|
||||
"min_p": get_param("min_p"),
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"repetition_penalty": get_param("repetition_penalty"),
|
||||
"regex": self.regex,
|
||||
"ebnf": self.ebnf,
|
||||
"n": self.n,
|
||||
"no_stop_trim": self.no_stop_trim,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"skip_special_tokens": self.skip_special_tokens,
|
||||
"logit_bias": self.logit_bias,
|
||||
}
|
||||
|
||||
if self.response_format and self.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
self.response_format.json_schema.schema_
|
||||
)
|
||||
elif self.response_format and self.response_format.type == "json_object":
|
||||
sampling_params["json_schema"] = '{"type": "object"}'
|
||||
elif self.response_format and self.response_format.type == "structural_tag":
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
self.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
# Check if there are already existing output constraints
|
||||
has_existing_constraints = (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
)
|
||||
|
||||
if tool_call_constraint and has_existing_constraints:
|
||||
logger.warning("Constrained decoding is not compatible with tool calls.")
|
||||
elif tool_call_constraint:
|
||||
constraint_type, constraint_value = tool_call_constraint
|
||||
if constraint_type == "structural_tag":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value.model_dump(by_alias=True)
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
|
||||
return sampling_params
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
||||
@@ -44,7 +44,6 @@ from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.parser.conversation import generate_chat_conv
|
||||
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
@@ -66,6 +65,15 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||
|
||||
# Get default sampling parameters from model's generation config
|
||||
self.default_sampling_params = (
|
||||
self.tokenizer_manager.model_config.get_default_sampling_params()
|
||||
)
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
f"Using default chat sampling params from model generation config: {self.default_sampling_params}",
|
||||
)
|
||||
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "chatcmpl-"
|
||||
|
||||
@@ -137,10 +145,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
processed_messages = self._process_messages(request, is_multimodal)
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(
|
||||
request,
|
||||
processed_messages.stop,
|
||||
processed_messages.tool_call_constraint,
|
||||
sampling_params = request.to_sampling_params(
|
||||
stop=processed_messages.stop,
|
||||
model_generation_config=self.default_sampling_params,
|
||||
tool_call_constraint=processed_messages.tool_call_constraint,
|
||||
)
|
||||
|
||||
# Handle single vs multiple requests
|
||||
@@ -410,72 +418,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _build_sampling_params(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
stop: List[str],
|
||||
tool_call_constraint: Optional[Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Build sampling parameters for the request"""
|
||||
|
||||
sampling_params = {
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"ebnf": request.ebnf,
|
||||
"n": request.n,
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
"skip_special_tokens": request.skip_special_tokens,
|
||||
"logit_bias": request.logit_bias,
|
||||
}
|
||||
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
request.response_format.json_schema.schema_
|
||||
)
|
||||
elif request.response_format and request.response_format.type == "json_object":
|
||||
sampling_params["json_schema"] = '{"type": "object"}'
|
||||
elif (
|
||||
request.response_format and request.response_format.type == "structural_tag"
|
||||
):
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
# Check if there are already existing output constraints
|
||||
has_existing_constraints = (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
)
|
||||
|
||||
if tool_call_constraint and has_existing_constraints:
|
||||
logger.warning("Constrained decoding is not compatible with tool calls.")
|
||||
elif tool_call_constraint:
|
||||
constraint_type, constraint_value = tool_call_constraint
|
||||
if constraint_type == "structural_tag":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value.model_dump(by_alias=True)
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
return sampling_params
|
||||
|
||||
async def _handle_streaming_request(
|
||||
self,
|
||||
adapted_request: GenerateReqInput,
|
||||
|
||||
@@ -252,6 +252,7 @@ class ServerArgs:
|
||||
reasoning_parser: Optional[str] = None
|
||||
tool_call_parser: Optional[str] = None
|
||||
tool_server: Optional[str] = None
|
||||
sampling_defaults: str = "model"
|
||||
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
@@ -1872,6 +1873,16 @@ class ServerArgs:
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-defaults",
|
||||
type=str,
|
||||
choices=["openai", "model"],
|
||||
default=ServerArgs.sampling_defaults,
|
||||
help="Where to get default sampling parameters. "
|
||||
"'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). "
|
||||
"'model' uses the model's generation_config.json to get the recommended "
|
||||
"sampling parameters if available. Default is 'model'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tool-server",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user