From ed54bf9d19f135679a1acd59c0d5f8ed8bc0bae3 Mon Sep 17 00:00:00 2001
From: JieXin Liang
Date: Sun, 15 Jun 2025 02:56:29 +0800
Subject: [PATCH] [fix] fix dsv3 weight loader tqdm and simplify shared
 experts fusion (#7181)

---
 python/sglang/srt/models/deepseek_v2.py | 107 +++---------------------
 1 file changed, 11 insertions(+), 96 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 869756568..0512fba87 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1033,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
@@ -2008,101 +2008,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale_inv",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale_inv",
-                        "up_proj.weight",
-                        "up_proj.weight_scale_inv",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.num_fused_shared_experts} "
-                "shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -2128,9 +2033,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]
 
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            logger.info("Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             weight_names.append(name)
 
             if not is_nextn: