[oai serving chat] Add argument --sampling-defaults and fix ChatCompletionRequest defaults (#11304)

Chang Su
2025-10-07 17:36:05 -07:00
committed by GitHub
parent fde9b96392
commit 7ba3de0e92
6 changed files with 198 additions and 126 deletions
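
The new --sampling-defaults argument controls where default sampling parameters come from. It defaults to "openai"; when set to "model", defaults are taken from the model's generation_config.json via the new ModelConfig.get_default_sampling_params helper added below. A hypothetical launch command (the model path is a placeholder, not part of this commit):

python -m sglang.launch_server --model-path <your-model> --sampling-defaults model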


@@ -17,7 +17,7 @@ import logging
import math
import os
from enum import Enum, IntEnum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

import torch
from transformers import PretrainedConfig
@@ -90,6 +90,7 @@ class ModelConfig:
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        sampling_defaults: str = "openai",
    ) -> None:
        # Parse args
        self.model_path = model_path
@@ -98,6 +99,7 @@ class ModelConfig:
        self.modelopt_quant = modelopt_quant
        self.is_draft_model = is_draft_model
        self.model_impl = model_impl
+        self.sampling_defaults = sampling_defaults

        # Get hf config
        self._maybe_pull_model_tokenizer_from_remote()
@@ -214,6 +216,7 @@ class ModelConfig:
            modelopt_quant=server_args.modelopt_quant,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
+            sampling_defaults=server_args.sampling_defaults,
            **kwargs,
        )
@@ -659,6 +662,38 @@ class ModelConfig:
            eos_ids = eos_ids | generation_eos_ids
        return eos_ids

+    def get_default_sampling_params(self) -> dict[str, Any]:
+        """
+        Get default sampling parameters from the model's generation config.
+
+        This method returns non-default sampling parameters from the model's
+        generation_config.json when sampling_defaults is set to "model".
+
+        Returns:
+            A dictionary containing the non-default sampling parameters.
+        """
+        if self.sampling_defaults != "model":
+            return {}
+
+        if self.hf_generation_config is None:
+            return {}
+
+        config = self.hf_generation_config.to_dict()
+
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+
+        default_sampling_params = {
+            p: config.get(p) for p in available_params if config.get(p) is not None
+        }
+
+        return default_sampling_params
+
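
As an aside (not part of the diff), here is a minimal sketch of how a serving layer could fold these model-provided defaults into per-request sampling parameters; model_config and request are assumed names, and the commit's actual ChatCompletionRequest handling may differ.

# Illustration only; `model_config` and `request` are hypothetical names.
defaults = model_config.get_default_sampling_params()  # {} unless --sampling-defaults model

sampling_params = {
    "temperature": request.temperature,
    "top_p": request.top_p,
    "top_k": request.top_k,
    "min_p": request.min_p,
    "repetition_penalty": request.repetition_penalty,
}
# Fall back to a generation_config.json value only for fields the request left unset.
for key, value in defaults.items():
    if sampling_params.get(key) is None:
        sampling_params[key] = value
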
    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary