From f47a2c67e6ef9174bbfa6d243ca59a4935d1ef57 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Sep 2025 16:48:12 -0700 Subject: [PATCH] [Auto Sync] Update load_config.py, model_config.py, configu... (20250923) (#10825) Co-authored-by: github-actions[bot] --- python/sglang/srt/configs/load_config.py | 4 + python/sglang/srt/configs/model_config.py | 182 ++++++++++-------- .../deep_gemm_wrapper/configurer.py | 2 - 3 files changed, 104 insertions(+), 84 deletions(-) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 6ac003ea4..c734bd2e6 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -24,6 +24,8 @@ class LoadFormat(str, enum.Enum): JAX = "jax" REMOTE = "remote" REMOTE_INSTANCE = "remote_instance" + RDMA = "rdma" + LOCAL_CACHED = "local_cached" @dataclass @@ -47,6 +49,7 @@ class LoadConfig: checkpoints. decryption_key_file: If set, decrypts the output files with a password read from this file (after PBKDF2). + decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit. """ load_format: Union[str, LoadFormat] = LoadFormat.AUTO @@ -54,6 +57,7 @@ class LoadConfig: model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None decryption_key_file: Optional[str] = None + decrypt_max_concurrency: int = -1 def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 46c610f00..7f9d83954 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -75,7 +75,10 @@ class ModelConfig: self.model_path = model_path self.revision = revision self.quantization = quantization + self.is_draft_model = is_draft_model self.model_impl = model_impl + + # TODO: remove these fields self.tp_rank = tp_rank self.remote_instance_weight_loader_seed_instance_ip = ( remote_instance_weight_loader_seed_instance_ip @@ -87,12 +90,12 @@ class ModelConfig: remote_instance_weight_loader_send_weights_group_ports ) - self.maybe_pull_model_tokenizer_from_remote() + # Get hf config + self._maybe_pull_model_tokenizer_from_remote() self.model_override_args = json.loads(model_override_args) kwargs = {} if override_config_file and override_config_file.strip(): kwargs["_configuration_file"] = override_config_file.strip() - self.hf_config = get_config( self.model_path, trust_remote_code=trust_remote_code, @@ -100,7 +103,7 @@ class ModelConfig: model_override_args=self.model_override_args, **kwargs, ) - + self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_generation_config = get_generation_config( self.model_path, trust_remote_code=trust_remote_code, @@ -108,7 +111,25 @@ class ModelConfig: **kwargs, ) - self.hf_text_config = get_hf_text_config(self.hf_config) + # Set enable_multimodal + if enable_multimodal is None: + mm_disabled_models = [ + "Gemma3ForConditionalGeneration", + "Llama4ForConditionalGeneration", + "Step3VLForConditionalGeneration", + ] + if self.hf_config.architectures[0] in mm_disabled_models: + enable_multimodal = False + logger.info( + f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal." + ) + else: + enable_multimodal = True + + # Config draft model + self._config_draft_model() + + # Check model type self.attention_chunk_size = getattr( self.hf_text_config, "attention_chunk_size", None ) @@ -124,20 +145,73 @@ class ModelConfig: self.hf_config.architectures, self.hf_text_config.num_hidden_layers ) ) + self.is_generation = is_generation_model( + self.hf_config.architectures, is_embedding + ) + self.is_multimodal = enable_multimodal and is_multimodal_model( + self.hf_config.architectures + ) + self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model( + self.hf_config.architectures + ) + self.is_image_gen = enable_multimodal and is_image_gen_model( + self.hf_config.architectures + ) + self.is_audio_model = enable_multimodal and is_audio_model( + self.hf_config.architectures + ) + self.is_multimodal_chunked_prefill_supported = ( + enable_multimodal + and is_multimodal_chunked_prefill_supported(self.hf_config.architectures) + ) + self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - if enable_multimodal is None: - mm_disabled_models = [ - "Gemma3ForConditionalGeneration", - "Llama4ForConditionalGeneration", - "Step3VLForConditionalGeneration", - ] - if self.hf_config.architectures[0] in mm_disabled_models: - enable_multimodal = False - logger.info( - f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal." - ) - else: - enable_multimodal = True + # Derive context length and model shapes + self._derive_context_length(context_length) + self._derive_model_shapes() + + # Verify quantization + self._verify_quantization() + + # Verify dual-chunk attention config + self._verify_dual_chunk_attention_config() + + # Cache attributes + self.hf_eos_token_id = self._get_hf_eos_token_id() + + # multimodal + self.image_token_id = getattr( + self.hf_config, "image_token_id", None + ) or getattr(self.hf_config, "image_token_index", None) + + @staticmethod + def from_server_args( + server_args: ServerArgs, + model_path: str = None, + model_revision: str = None, + **kwargs, + ): + return ModelConfig( + model_path=model_path or server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=model_revision or server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + enable_multimodal=server_args.enable_multimodal, + dtype=server_args.dtype, + quantization=server_args.quantization, + hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, + model_impl=server_args.model_impl, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, + **kwargs, + ) + + def _config_draft_model(self): + is_draft_model = self.is_draft_model if ( is_draft_model @@ -172,31 +246,10 @@ class ModelConfig: self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP" self.hf_config.num_nextn_predict_layers = 1 - # Check model type - self.is_generation = is_generation_model( - self.hf_config.architectures, is_embedding - ) - self.is_multimodal = enable_multimodal and is_multimodal_model( - self.hf_config.architectures - ) - self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model( - self.hf_config.architectures - ) - self.is_image_gen = enable_multimodal and is_image_gen_model( - self.hf_config.architectures - ) - self.is_audio_model = enable_multimodal and is_audio_model( - self.hf_config.architectures - ) - self.is_multimodal_chunked_prefill_supported = ( - enable_multimodal - and is_multimodal_chunked_prefill_supported(self.hf_config.architectures) - ) - self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) - self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - - # Derive context length + def _derive_context_length(self, context_length: int): + is_draft_model = self.is_draft_model derived_context_len = get_context_length(self.hf_text_config) + if context_length is not None: if context_length > derived_context_len: reason = "Target model's" if is_draft_model else "User-specified" @@ -224,6 +277,10 @@ class ModelConfig: else: self.context_len = derived_context_len + # Transfer context_len to HuggingFace config so models can access it + self.hf_config.context_len = self.context_len + + def _derive_model_shapes(self): # Unify the config keys for hf_text_config self.head_dim = getattr( self.hf_text_config, @@ -318,45 +375,6 @@ class ModelConfig: ) self.vocab_size = self.hf_text_config.vocab_size - # Verify quantization - self._verify_quantization() - - # Verify dual-chunk attention config - self._verify_dual_chunk_attention_config() - - # Cache attributes - self.hf_eos_token_id = self.get_hf_eos_token_id() - - # multimodal - self.image_token_id = getattr( - self.hf_config, "image_token_id", None - ) or getattr(self.hf_config, "image_token_index", None) - - @staticmethod - def from_server_args( - server_args: ServerArgs, - model_path: str = None, - model_revision: str = None, - **kwargs, - ): - return ModelConfig( - model_path=model_path or server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=model_revision or server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - enable_multimodal=server_args.enable_multimodal, - dtype=server_args.dtype, - quantization=server_args.quantization, - hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, - model_impl=server_args.model_impl, - remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, - remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, - remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, - **kwargs, - ) - def get_total_num_attention_heads(self) -> int: return self.num_attention_heads @@ -591,7 +609,7 @@ class ModelConfig: "sparse_attention_enabled" ] = True - def get_hf_eos_token_id(self) -> Optional[Set[int]]: + def _get_hf_eos_token_id(self) -> Optional[Set[int]]: eos_ids = getattr(self.hf_config, "eos_token_id", None) if eos_ids is not None: # it can be either int or list of int @@ -611,7 +629,7 @@ class ModelConfig: eos_ids = eos_ids | generation_eos_ids return eos_ids - def maybe_pull_model_tokenizer_from_remote(self) -> None: + def _maybe_pull_model_tokenizer_from_remote(self) -> None: """ Pull the model config files to a temporary directory in case of remote. diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py index ab2c4191b..62073e38c 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py @@ -1,7 +1,5 @@ import logging -import torch - from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell logger = logging.getLogger(__name__)