From 87975fa058fe3f90d204ded42a08989a8dcb413e Mon Sep 17 00:00:00 2001
From: ApsarasX
Date: Tue, 29 Apr 2025 17:06:19 +0800
Subject: [PATCH] [Bugfix] Fix early return in CustomDeepseekV2MoE.forward
 during profile_run (#682)

### What this PR does / why we need it?
Fix #674 to avoid KVCache overallocation and OOM risks.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Signed-off-by: ApsarasX
---
 vllm_ascend/models/deepseek_v2.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 76f9468..e33e8df 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -143,15 +143,16 @@ class CustomDeepseekV2MoE(nn.Module):
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
-            return hidden_states
+            is_prefill = True
+        else:
+            is_prefill = attn_metadata.num_prefills > 0
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        if (self.tp_size > 1 and self.enable_mc2
-                and attn_metadata.num_prefills == 0):
+        if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
             chunks = torch.chunk(hidden_states,
                                  get_tp_group().world_size,
                                  dim=0)
@@ -159,8 +160,7 @@ class CustomDeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        is_prefill = True if attn_metadata.num_prefills > 0 else False
-        # is_prefill = attn_metadata.num_prefills > 0
+
        final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
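
For reviewers, a minimal sketch of the control flow after this patch. During vLLM's memory profiling pass, `attn_metadata` is `None`; the old early `return hidden_states` skipped the expert layers entirely, so profiling under-measured activation memory and the engine sized the KV cache too large, risking OOM at runtime. Treating the profile run as a prefill keeps `is_prefill` defined and exercises the full MoE path. `AttnMetadata` and `moe_forward` below are simplified stand-ins, not the real vLLM types:

```python
# Minimal sketch of the patched control flow. This is NOT the vllm-ascend
# source: AttnMetadata and moe_forward are hypothetical stand-ins, and the
# doubling at the end is a placeholder for self.experts(...).
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AttnMetadata:
    num_prefills: int


def moe_forward(hidden_states: List[float],
                attn_metadata: Optional[AttnMetadata]) -> List[float]:
    if attn_metadata is None:
        # Profile run: previously the function returned hidden_states here,
        # skipping the experts. Now it is treated as a prefill so the full
        # MoE path (and its peak activation memory) is exercised.
        is_prefill = True
    else:
        is_prefill = attn_metadata.num_prefills > 0

    if not is_prefill:
        # Decode-only batch: in the real module, this is where MC2
        # tensor-parallel chunking applies (tp_size > 1 and enable_mc2).
        pass

    return [h * 2.0 for h in hidden_states]  # placeholder for self.experts(...)


# The profile run (attn_metadata is None) now reaches the experts:
print(moe_forward([1.0, 2.0], None))             # [2.0, 4.0]
print(moe_forward([1.0, 2.0], AttnMetadata(0)))  # decode-only batch
```

Hoisting `is_prefill` above the MC2 branch also removes the direct `attn_metadata.num_prefills` access from that condition, which would otherwise dereference `None` once the early return is gone.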