From dd408ee4815ce505d48f1f590058aab41a856aaf Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Wed, 30 Apr 2025 07:25:40 +0800
Subject: [PATCH] Auto set draft model path for MTP (#5793)

---
 python/sglang/srt/configs/model_config.py     |   7 +
 python/sglang/srt/managers/tp_worker.py       |   1 +
 .../sglang/srt/model_executor/model_runner.py |  13 +-
 python/sglang/srt/models/deepseek_nextn.py    | 258 +-----------------
 python/sglang/srt/models/deepseek_v2.py       |  91 ++++--
 python/sglang/srt/server_args.py              |  32 ++-
 6 files changed, 115 insertions(+), 287 deletions(-)

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 13dacee37..787c367f6 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -47,6 +47,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:
         self.model_path = model_path
@@ -85,6 +86,12 @@ class ModelConfig:
         else:
             enable_multimodal = True

+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index a07dbfb07..c7666ffc6 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -71,6 +71,7 @@ class TpModelWorker:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3eb798651..231428b83 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -692,9 +692,14 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * self.model_config.num_hidden_layers
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
@@ -809,7 +814,11 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index 4849177af..d01bc3ae9 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)


 EntryClass = [DeepseekV3ForCausalLMNextN]
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c24c098c0..01063a298 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):
         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
"up_proj.weight_scale_inv", ] names_to_remove = [] - for moe_layer in tqdm( + + moe_layers = ( range( self.config.first_k_dense_replace, self.config.num_hidden_layers, self.config.moe_layer_freq, - ), + ) + if not is_nextn + else [nextn_layer_id] + ) + + for moe_layer in tqdm( + moe_layers, desc=f"Cloning {self.n_share_experts_fusion} " "replicas of the shared expert into MoE", ): @@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module): ) cached_a_proj = {} if fuse_qkv_a_proj else None + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names = [ + "shared_head.norm", + "eh_proj", + "enorm", + "hnorm", + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # TODO(HandH1998): Modify it when nextn is supported. - if hasattr(self.config, "num_nextn_predict_layers"): - num_nextn_layers = self.config.num_nextn_predict_layers - if num_nextn_layers > 0 and name.startswith("model.layers"): - name_list = name.split(".") - if ( - len(name_list) >= 3 - and int(name_list[2]) >= self.config.num_hidden_layers - ): - continue + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) >= self.config.num_hidden_layers + ): + continue + else: + if not name.startswith(nextn_layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in stacked_params_mapping: @@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module): ) weight_loader(param, loaded_weight) - self.post_load_weights() + self.post_load_weights(is_nextn=is_nextn) def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index eace0ade5..d06ae705f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -22,7 +22,7 @@ import random import tempfile from typing import List, Literal, Optional -from sglang.srt.hf_transformers_utils import check_gguf_file +from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( configure_ipv6, @@ -333,6 +333,14 @@ class ServerArgs: "eagle speculative decoding." 
             )

+            model_arch = get_model_arch(self)
+
+            # Auto set draft_model_path for DeepSeek-V3/R1
+            if self.speculative_draft_model_path is None and model_arch in [
+                "DeepseekV3ForCausalLM"
+            ]:
+                self.speculative_draft_model_path = self.model_path
+
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (
@@ -343,7 +351,7 @@
                     self.speculative_num_steps,
                     self.speculative_eagle_topk,
                     self.speculative_num_draft_tokens,
-                ) = auto_choose_speculative_params(self)
+                ) = auto_choose_speculative_params(model_arch)

         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
         raise ValueError(self.help)


-def auto_choose_speculative_params(self: ServerArgs):
+def get_model_arch(args: ServerArgs):
+    hf_config = get_config(
+        args.model_path,
+        trust_remote_code=args.trust_remote_code,
+        revision=args.revision,
+        model_override_args=json.loads(args.json_model_override_args),
+    )
+    return hf_config.architectures[0]
+
+
+def auto_choose_speculative_params(arch: str):
     """
     Automatically choose the parameters for speculative decoding.

     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
-    config_path = os.path.join(self.model_path, "config.json")
-    if not os.path.exists(config_path):
-        raise ValueError(f"{config_path} is not found.")
-
-    config = json.load(open(config_path))
-
-    arch = config.get("architectures", ["Unknown"])[0]
-
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
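---

Usage sketch (not part of the diff): after this change, enabling MTP for a
DeepSeek-V3/R1 checkpoint no longer requires an explicit
--speculative-draft-model-path; ServerArgs copies model_path into
speculative_draft_model_path when the detected architecture is
DeepseekV3ForCausalLM. The snippet below assumes this logic runs in
ServerArgs.__post_init__ and that speculative_algorithm="EAGLE" selects the
MTP/EAGLE path, as in sglang releases around this commit; any field not
visible in the hunks above is an assumption and may differ by version.

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V3",  # any DeepseekV3ForCausalLM checkpoint
        speculative_algorithm="EAGLE",  # assumed: MTP drafting reuses the EAGLE path
        trust_remote_code=True,
    )

    # The draft path was auto-filled, and auto_choose_speculative_params(model_arch)
    # picked num_steps / eagle_topk / num_draft_tokens from the architecture.
    assert args.speculative_draft_model_path == args.model_path
    print(
        args.speculative_num_steps,
        args.speculative_eagle_topk,
        args.speculative_num_draft_tokens,
    )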