diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index f31970622..37722c492 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -261,6 +261,9 @@ class ModelConfig: self.num_key_value_heads = self.num_attention_heads self.hidden_size = self.hf_text_config.hidden_size self.num_hidden_layers = self.hf_text_config.num_hidden_layers + self.num_nextn_predict_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", None + ) self.vocab_size = self.hf_text_config.vocab_size # Verify quantization diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 13555adeb..02389108a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -285,11 +285,21 @@ class ModelRunner: if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid = self.model_config.is_hybrid = True - self.start_layer = getattr(self.model, "start_layer", 0) - self.end_layer = getattr( - self.model, "end_layer", self.model_config.num_hidden_layers + # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft + # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to + # determine the number of layers. + model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None + model_num_layers = ( + self.model_config.num_nextn_predict_layers + if self.is_draft_worker and model_has_mtp_layers + else self.model_config.num_hidden_layers ) + self.start_layer = getattr(self.model, "start_layer", 0) + self.end_layer = getattr(self.model, "end_layer", model_num_layers) self.num_effective_layers = self.end_layer - self.start_layer + assert (not model_has_mtp_layers) or ( + self.num_effective_layers == model_num_layers + ), "PP is not compatible with MTP models." # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) @@ -1178,11 +1188,7 @@ class ModelRunner: dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, - layer_num=( - self.model_config.num_hidden_layers - if not self.is_draft_worker - else self.model_config.hf_config.num_nextn_predict_layers - ), # PP is not compatible with mla backend + layer_num=self.num_effective_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, start_layer=self.start_layer, @@ -1195,11 +1201,7 @@ class ModelRunner: dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, - layer_num=( - self.model_config.num_hidden_layers - if not self.is_draft_worker - else self.model_config.hf_config.num_nextn_predict_layers - ), # PP is not compatible with mla backend + layer_num=self.num_effective_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, start_layer=self.start_layer,