[Bugfix] Fix early return in CustomDeepseekV2MoE.forward during profile_run (#682)
### What this PR does / why we need it? Fix #674 to avoild KVCache overallocation and OOM risks. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -143,15 +143,16 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# for profile run
|
# for profile run
|
||||||
return hidden_states
|
is_prefill = True
|
||||||
|
else:
|
||||||
|
is_prefill = attn_metadata.num_prefills > 0
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
if self.n_shared_experts is not None:
|
if self.n_shared_experts is not None:
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
|
|
||||||
if (self.tp_size > 1 and self.enable_mc2
|
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
||||||
and attn_metadata.num_prefills == 0):
|
|
||||||
chunks = torch.chunk(hidden_states,
|
chunks = torch.chunk(hidden_states,
|
||||||
get_tp_group().world_size,
|
get_tp_group().world_size,
|
||||||
dim=0)
|
dim=0)
|
||||||
@@ -159,8 +160,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
is_prefill = True if attn_metadata.num_prefills > 0 else False
|
|
||||||
# is_prefill = attn_metadata.num_prefills > 0
|
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user