[Bugfix] Fix early return in CustomDeepseekV2MoE.forward during profile_run (#682)
### What this PR does / why we need it?
Fixes #674 to avoid KV cache overallocation and OOM risks.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Signed-off-by: ApsarasX <apsarax@outlook.com>
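For context on the failure mode: vLLM sizes the KV cache from whatever accelerator memory the initial profile run leaves free, so a layer that skips its work during profiling makes the measured activation peak too low and the KV cache correspondingly too large. A minimal sketch of that arithmetic, with hypothetical names and numbers (not the actual vLLM profiling code):

```python
# Illustrative sketch of KV cache sizing during the profile run.
# plan_kv_cache_bytes and all numbers here are hypothetical.

GIB = 1024**3

def plan_kv_cache_bytes(total_bytes: int, weights_bytes: int,
                        peak_activation_bytes: int) -> int:
    """The KV cache is granted whatever memory the profile run left unused."""
    return total_bytes - weights_bytes - peak_activation_bytes

total, weights = 64 * GIB, 40 * GIB

# Buggy profile run: CustomDeepseekV2MoE.forward returns early, so the MoE
# activations are never materialized and the measured peak comes out too low.
kv_buggy = plan_kv_cache_bytes(total, weights, peak_activation_bytes=2 * GIB)

# Fixed profile run: the MoE layer executes, so the measured peak is realistic.
kv_fixed = plan_kv_cache_bytes(total, weights, peak_activation_bytes=6 * GIB)

# The buggy plan hands the missing 4 GiB to the KV cache; at serving time the
# MoE activations still need that memory, so allocation can fail with OOM.
assert kv_buggy - kv_fixed == 4 * GIB
```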
```diff
@@ -143,15 +143,16 @@ class CustomDeepseekV2MoE(nn.Module):
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
-            return hidden_states
+            is_prefill = True
+        else:
+            is_prefill = attn_metadata.num_prefills > 0
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        if (self.tp_size > 1 and self.enable_mc2
-                and attn_metadata.num_prefills == 0):
+        if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
             chunks = torch.chunk(hidden_states,
                                  get_tp_group().world_size,
                                  dim=0)
@@ -159,8 +160,7 @@ class CustomDeepseekV2MoE(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        is_prefill = True if attn_metadata.num_prefills > 0 else False
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
```
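The fixed branch can be mirrored as a small standalone helper; the sketch below paraphrases the diff above with the reasoning spelled out in comments (`AttnMetadataStub` is a stand-in, not the real vllm-ascend type):

```python
from typing import Optional

class AttnMetadataStub:
    """Stand-in carrying only the field the branch inspects."""

    def __init__(self, num_prefills: int):
        self.num_prefills = num_prefills

def resolve_is_prefill(attn_metadata: Optional[AttnMetadataStub]) -> bool:
    """Mirrors the fixed branch in CustomDeepseekV2MoE.forward.

    During the profile run attn_metadata is None. Instead of returning
    early (which hid the MoE computation from the memory profiler), the
    layer now treats the profile run as prefill and keeps executing.
    """
    if attn_metadata is None:
        return True  # profile run: behave like prefill
    return attn_metadata.num_prefills > 0

assert resolve_is_prefill(None) is True                               # profile run
assert resolve_is_prefill(AttnMetadataStub(num_prefills=2)) is True   # prefill batch
assert resolve_is_prefill(AttnMetadataStub(num_prefills=0)) is False  # decode-only batch
```

Hoisting `is_prefill` above the MC2 chunking branch also matters for correctness: the old condition dereferenced `attn_metadata.num_prefills`, which would raise `AttributeError` on the profile run now that execution no longer stops at the early return, whereas `not is_prefill` is safe in both cases.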