diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 76f9468..e33e8df 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -143,15 +143,16 @@ class CustomDeepseekV2MoE(nn.Module): attn_metadata = get_forward_context().attn_metadata if attn_metadata is None: # for profile run - return hidden_states + is_prefill = True + else: + is_prefill = attn_metadata.num_prefills > 0 num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) - if (self.tp_size > 1 and self.enable_mc2 - and attn_metadata.num_prefills == 0): + if (self.tp_size > 1 and self.enable_mc2 and not is_prefill): chunks = torch.chunk(hidden_states, get_tp_group().world_size, dim=0) @@ -159,8 +160,7 @@ class CustomDeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - is_prefill = True if attn_metadata.num_prefills > 0 else False - # is_prefill = attn_metadata.num_prefills > 0 + final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits,