[oai serving chat] Add argument --sampling-defaults and fix ChatCompletionRequest defaults (#11304)
@@ -17,7 +17,7 @@ import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -90,6 +90,7 @@ class ModelConfig:
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        sampling_defaults: str = "openai",
     ) -> None:
         # Parse args
         self.model_path = model_path
@@ -98,6 +99,7 @@ class ModelConfig:
         self.modelopt_quant = modelopt_quant
         self.is_draft_model = is_draft_model
         self.model_impl = model_impl
+        self.sampling_defaults = sampling_defaults
 
         # Get hf config
         self._maybe_pull_model_tokenizer_from_remote()
@@ -214,6 +216,7 @@ class ModelConfig:
             modelopt_quant=server_args.modelopt_quant,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             model_impl=server_args.model_impl,
+            sampling_defaults=server_args.sampling_defaults,
             **kwargs,
         )
 
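Taken together, the hunks above thread the new setting from the server arguments into ModelConfig. A minimal usage sketch, assuming the sglang import paths and the from_server_args constructor that the hunk above appears to modify (the model path is a placeholder):

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs

# "openai" (the default) keeps OpenAI-compatible neutral defaults;
# "model" opts into defaults from the model's generation_config.json.
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    sampling_defaults="model",
)
model_config = ModelConfig.from_server_args(server_args)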
@@ -659,6 +662,38 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids
 
+    def get_default_sampling_params(self) -> dict[str, Any]:
+        """
+        Get default sampling parameters from the model's generation config.
+
+        This method returns non-default sampling parameters from the model's
+        generation_config.json when sampling_defaults is set to "model".
+
+        Returns:
+            A dictionary containing the non-default sampling parameters.
+        """
+        if self.sampling_defaults != "model":
+            return {}
+
+        if self.hf_generation_config is None:
+            return {}
+
+        config = self.hf_generation_config.to_dict()
+
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+
+        default_sampling_params = {
+            p: config.get(p) for p in available_params if config.get(p) is not None
+        }
+
+        return default_sampling_params
+
     def _maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
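For reference, a self-contained sketch of the selection logic in get_default_sampling_params; the stub class below merely stands in for the transformers GenerationConfig object and is not part of this commit:

# Standalone restatement of the method's logic; _StubGenerationConfig is an
# illustrative test double for transformers.GenerationConfig.
class _StubGenerationConfig:
    def __init__(self, **values):
        self._values = values

    def to_dict(self):
        return dict(self._values)


def default_sampling_params(sampling_defaults, hf_generation_config):
    # Only consult generation_config.json when the user opted into
    # model-provided defaults and a generation config actually exists.
    if sampling_defaults != "model" or hf_generation_config is None:
        return {}
    config = hf_generation_config.to_dict()
    available_params = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]
    # Keep only the parameters the model explicitly sets (non-None).
    return {p: config.get(p) for p in available_params if config.get(p) is not None}


# --sampling-defaults openai (the default): model hints are ignored.
assert default_sampling_params("openai", _StubGenerationConfig(temperature=0.6)) == {}

# --sampling-defaults model: non-None hints from generation_config.json apply.
assert default_sampling_params(
    "model", _StubGenerationConfig(temperature=0.6, top_p=0.95, top_k=None)
) == {"temperature": 0.6, "top_p": 0.95}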