From b6cf3532b504081ea1d5b52244c0e889b7544856 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 8 May 2025 16:02:43 +0800 Subject: [PATCH] Tiny refactor ModelConfig.from_server_args (#5219) --- python/sglang/bench_one_batch.py | 12 +----------- python/sglang/srt/configs/model_config.py | 16 ++++++++++++++++ python/sglang/srt/managers/scheduler.py | 12 +----------- python/sglang/srt/managers/tokenizer_manager.py | 12 +----------- python/sglang/srt/managers/tp_worker.py | 13 +++---------- test/srt/test_gptqmodel_dynamic.py | 11 +---------- 6 files changed, 23 insertions(+), 53 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index e70f3af2d..09da170a8 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - enable_multimodal=server_args.enable_multimodal, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + model_config = ModelConfig.from_server_args(server_args) model_runner = ModelRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 06b4ca267..85a4f3153 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -24,6 +24,7 @@ from transformers import PretrainedConfig from sglang.srt.hf_transformers_utils import get_config, get_context_length from sglang.srt.layers.quantization import QUANTIZATION_METHODS 
+from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var, is_hip logger = logging.getLogger(__name__) @@ -210,6 +211,21 @@ class ModelConfig: self.hf_eos_token_id = self.get_hf_eos_token_id() self.image_token_id = getattr(self.hf_config, "image_token_id", None) + @classmethod + def from_server_args(cls, server_args: ServerArgs, model_path: Optional[str] = None, **kwargs): + return cls( + model_path=model_path or server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + enable_multimodal=server_args.enable_multimodal, + dtype=server_args.dtype, + quantization=server_args.quantization, + **kwargs, + ) + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2158ccae6..b69bcd140 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -455,17 +455,7 @@ class Scheduler( def init_tokenizer(self): server_args = self.server_args - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - enable_multimodal=server_args.enable_multimodal, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + self.model_config = ModelConfig.from_server_args(server_args) self.is_generation = self.model_config.is_generation if server_args.skip_tokenizer_init: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 185d912c0..9db63c881 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -165,17 +165,7 @@ class TokenizerManager: # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - enable_multimodal=server_args.enable_multimodal, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + self.model_config = ModelConfig.from_server_args(server_args) self.is_generation = self.model_config.is_generation self.is_image_gen = self.model_config.is_image_gen diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 6c7eb6535..faed34665 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -65,20 +65,13 @@ class TpModelWorker: self.pp_rank = pp_rank # Init model and tokenizer - self.model_config = ModelConfig( - ( + self.model_config = ModelConfig.from_server_args( + server_args, + model_path=( server_args.model_path if not is_draft_worker else server_args.speculative_draft_model_path ), - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - enable_multimodal=server_args.enable_multimodal, - dtype=server_args.dtype, - quantization=server_args.quantization, is_draft_model=is_draft_worker, ) diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index 54dbaf496..27ccd9a4b 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ 
-43,16 +43,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): pass server_args = ServerArgs(model_path=model_path, dtype=torch.float16) - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig() device_config = DeviceConfig("cuda")