From 094541296ed37265062c7cd1bd2302d0790cfd5b Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Wed, 11 Feb 2026 12:28:36 +0800 Subject: [PATCH] add deepseekv3 --- .../model_executor/models/deepseek_v2.py | 90 ++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py index d70cbfd..e45f792 100644 --- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py +++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm_mlu.model_executor.layers.feed_forward import FeedForward from vllm_mlu.mlu_hijack_utils import MluHijackObject from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp +from vllm import _mlu_ops as mlu_ops from vllm.utils import print_warning_once from vllm.model_executor.models.utils import is_pp_missing_parameter from vllm_mlu.model_executor.models.layer_utils import quant_fusion_with_rmsnorm @@ -77,6 +78,12 @@ class DeepseekV2MoE(SparseMoeMlp): bias=False, quant_config=None, prefix=f"{prefix}.gate") + if getattr(config, "topk_method", None) == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -104,6 +111,7 @@ class DeepseekV2MoE(SparseMoeMlp): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -113,9 +121,25 @@ class DeepseekV2MoE(SparseMoeMlp): 
Modify by vllm_mlu ============================= @brief: replace experts() with forward_experts, which defined by SparseMoeMlp. + For noaux_tc (DeepSeek V3), do manual routing with e_score_correction_bias. ''' - final_hidden_states = self.forward_experts( - hidden_states, router_logits) * self.routed_scaling_factor + if self.gate.e_score_correction_bias is not None: + # noaux_tc routing (DeepSeek V3 uses sigmoid scoring): bias only biases expert selection + scores = router_logits.float().sigmoid() + scores_for_choice = scores + self.gate.e_score_correction_bias.unsqueeze(0) + _, topk_indices = torch.topk( + scores_for_choice, k=self.top_k, dim=-1) + # Weights come from unbiased scores; NOTE(review): n_group/topk_group group-limited top-k not applied — confirm + topk_weights = scores.gather(1, topk_indices) + if self.renormalize: + topk_weights = topk_weights / topk_weights.sum( + dim=-1, keepdim=True) + final_hidden_states = self.forward_experts_with_precomputed_routing( + hidden_states, topk_weights, topk_indices + ) * self.routed_scaling_factor + else: + final_hidden_states = self.forward_experts( + hidden_states, router_logits) * self.routed_scaling_factor ''' ================== End of MLU Hijack @@ -129,6 +153,55 @@ class DeepseekV2MoE(SparseMoeMlp): return final_hidden_states.view(num_tokens, hidden_dim) + def forward_experts_with_precomputed_routing( + self, hidden_states, topk_weights, topk_indices + ): + """Run the MoE forward pass with precomputed routing weights/indices.""" + self.pack_params() + ori_input_shape = hidden_states.shape + expert_num = self.num_total_experts + expert_size = self.w13.size(0) + max_m = hidden_states.shape[0] + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + reduce_weight = topk_weights.to(torch.float32) + expert_id = topk_indices.to(torch.int32) + + # gen_idx + expand_idx, combine_idx, token_count, cusum_token_count = ( + mlu_ops.moe_gen_idx(expert_id, expert_num) + ) + + start_expert_id = self.start_expert_id + # gemm1 + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, expand_idx, 
cusum_token_count, + start_expert_id, expert_size + ) + gemm1_out = mlu_ops.group_gemm( + expand_hidden_states, self.w13, + token_count[start_expert_id:start_expert_id + expert_size], + None, None, None, None, max_m + ) + # activation + act_out = mlu_ops.moe_active( + gemm1_out, self.hidden_act, self.is_gated, None, self.b13, + cusum_token_count, start_expert_id, expert_size + ) + # gemm2 + gemm2_out = mlu_ops.group_gemm( + act_out, self.w2, + token_count[start_expert_id:start_expert_id + expert_size], + None, None, None, None, max_m + ) + # combine + output = mlu_ops.moe_combine_result( + gemm2_out, reduce_weight, combine_idx, + None, cusum_token_count, start_expert_id, + expert_size, self.b2 + ) + return output.view(ori_input_shape) + def forward_prefill( self, positions: torch.Tensor, @@ -491,6 +564,15 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) +def get_spec_layer_idx_from_weight_name(config, weight_name): + num_nextn = getattr(config, "num_nextn_predict_layers", 0) + if num_nextn and num_nextn > 0: + layer_idx = config.num_hidden_layers + for i in range(num_nextn): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None + def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ''' ============================= @@ -530,6 +612,10 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + # Skip MTP speculative decoding layer weights + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue ''' ============================= Modify by vllm_mlu