[Bugfix] Fix early return in CustomDeepseekV2MoE.forward during profile_run (#682)

### What this PR does / why we need it?

Fix #674 to avoid KVCache overallocation and OOM risks.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
ApsarasX
2025-04-29 17:06:19 +08:00
committed by GitHub
parent 7aee9228f0
commit 87975fa058

View File

@@ -143,15 +143,16 @@ class CustomDeepseekV2MoE(nn.Module):
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
# for profile run
return hidden_states
is_prefill = True
else:
is_prefill = attn_metadata.num_prefills > 0
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if (self.tp_size > 1 and self.enable_mc2
and attn_metadata.num_prefills == 0):
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
dim=0)
@@ -159,8 +160,7 @@ class CustomDeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
is_prefill = True if attn_metadata.num_prefills > 0 else False
# is_prefill = attn_metadata.num_prefills > 0
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,