[Auto Sync] Update load_config.py, model_config.py, configu... (20250923) (#10825)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,8 @@ class LoadFormat(str, enum.Enum):
|
||||
JAX = "jax"
|
||||
REMOTE = "remote"
|
||||
REMOTE_INSTANCE = "remote_instance"
|
||||
RDMA = "rdma"
|
||||
LOCAL_CACHED = "local_cached"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -47,6 +49,7 @@ class LoadConfig:
|
||||
checkpoints.
|
||||
decryption_key_file: If set, decrypts the output files with a password read
|
||||
from this file (after PBKDF2).
|
||||
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||
@@ -54,6 +57,7 @@ class LoadConfig:
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
decryption_key_file: Optional[str] = None
|
||||
decrypt_max_concurrency: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
|
||||
@@ -75,7 +75,10 @@ class ModelConfig:
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
self.is_draft_model = is_draft_model
|
||||
self.model_impl = model_impl
|
||||
|
||||
# TODO: remove these fields
|
||||
self.tp_rank = tp_rank
|
||||
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||
remote_instance_weight_loader_seed_instance_ip
|
||||
@@ -87,12 +90,12 @@ class ModelConfig:
|
||||
remote_instance_weight_loader_send_weights_group_ports
|
||||
)
|
||||
|
||||
self.maybe_pull_model_tokenizer_from_remote()
|
||||
# Get hf config
|
||||
self._maybe_pull_model_tokenizer_from_remote()
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
kwargs = {}
|
||||
if override_config_file and override_config_file.strip():
|
||||
kwargs["_configuration_file"] = override_config_file.strip()
|
||||
|
||||
self.hf_config = get_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@@ -100,7 +103,7 @@ class ModelConfig:
|
||||
model_override_args=self.model_override_args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.hf_generation_config = get_generation_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@@ -108,7 +111,25 @@ class ModelConfig:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
# Set enable_multimodal
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
logger.info(
|
||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||
)
|
||||
else:
|
||||
enable_multimodal = True
|
||||
|
||||
# Config draft model
|
||||
self._config_draft_model()
|
||||
|
||||
# Check model type
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
@@ -124,20 +145,73 @@ class ModelConfig:
|
||||
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
||||
)
|
||||
)
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_chunked_prefill_supported = (
|
||||
enable_multimodal
|
||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
logger.info(
|
||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||
)
|
||||
else:
|
||||
enable_multimodal = True
|
||||
# Derive context length and model shapes
|
||||
self._derive_context_length(context_length)
|
||||
self._derive_model_shapes()
|
||||
|
||||
# Verify quantization
|
||||
self._verify_quantization()
|
||||
|
||||
# Verify dual-chunk attention config
|
||||
self._verify_dual_chunk_attention_config()
|
||||
|
||||
# Cache attributes
|
||||
self.hf_eos_token_id = self._get_hf_eos_token_id()
|
||||
|
||||
# multimodal
|
||||
self.image_token_id = getattr(
|
||||
self.hf_config, "image_token_id", None
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
def from_server_args(
|
||||
server_args: ServerArgs,
|
||||
model_path: str = None,
|
||||
model_revision: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _config_draft_model(self):
|
||||
is_draft_model = self.is_draft_model
|
||||
|
||||
if (
|
||||
is_draft_model
|
||||
@@ -172,31 +246,10 @@ class ModelConfig:
|
||||
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
||||
self.hf_config.num_nextn_predict_layers = 1
|
||||
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_chunked_prefill_supported = (
|
||||
enable_multimodal
|
||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Derive context length
|
||||
def _derive_context_length(self, context_length: int):
|
||||
is_draft_model = self.is_draft_model
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
|
||||
if context_length is not None:
|
||||
if context_length > derived_context_len:
|
||||
reason = "Target model's" if is_draft_model else "User-specified"
|
||||
@@ -224,6 +277,10 @@ class ModelConfig:
|
||||
else:
|
||||
self.context_len = derived_context_len
|
||||
|
||||
# Transfer context_len to HuggingFace config so models can access it
|
||||
self.hf_config.context_len = self.context_len
|
||||
|
||||
def _derive_model_shapes(self):
|
||||
# Unify the config keys for hf_text_config
|
||||
self.head_dim = getattr(
|
||||
self.hf_text_config,
|
||||
@@ -318,45 +375,6 @@ class ModelConfig:
|
||||
)
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
# Verify quantization
|
||||
self._verify_quantization()
|
||||
|
||||
# Verify dual-chunk attention config
|
||||
self._verify_dual_chunk_attention_config()
|
||||
|
||||
# Cache attributes
|
||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||
|
||||
# multimodal
|
||||
self.image_token_id = getattr(
|
||||
self.hf_config, "image_token_id", None
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
def from_server_args(
|
||||
server_args: ServerArgs,
|
||||
model_path: str = None,
|
||||
model_revision: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_total_num_attention_heads(self) -> int:
|
||||
return self.num_attention_heads
|
||||
|
||||
@@ -591,7 +609,7 @@ class ModelConfig:
|
||||
"sparse_attention_enabled"
|
||||
] = True
|
||||
|
||||
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
||||
if eos_ids is not None:
|
||||
# it can be either int or list of int
|
||||
@@ -611,7 +629,7 @@ class ModelConfig:
|
||||
eos_ids = eos_ids | generation_eos_ids
|
||||
return eos_ids
|
||||
|
||||
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
"""
|
||||
Pull the model config files to a temporary
|
||||
directory in case of remote.
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Reference in New Issue
Block a user