Tiny refactor ModelConfig.from_server_args (#5219)

This commit is contained in:
fzyzcjy
2025-05-08 16:02:43 +08:00
committed by GitHub
parent 3b2680a44d
commit b6cf3532b5
6 changed files with 23 additions and 53 deletions

View File

@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
logger = logging.getLogger(__name__)
@@ -210,6 +211,21 @@ class ModelConfig:
self.hf_eos_token_id = self.get_hf_eos_token_id()
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
@staticmethod
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
return ModelConfig(
model_path=model_path or server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
**kwargs,
)
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""