Speed up DeepEP when the batch contains padding tokens (#6175)
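This change threads the full ForwardBatch through DeepseekV2MLP.forward and DeepseekV2MoE.forward instead of only the ForwardMode, so forward_deepep can read forward_batch.num_token_non_padded and hand it to the top-k routing. With CUDA-graph padding, the tail of the batch is filler tokens; telling the router how many tokens are real lets DeepEP avoid dispatching the padded ones to experts.

A minimal sketch of the idea (not the actual SGLang/DeepEP kernel; `topk_with_padding_mask`, the `-1` expert-id sentinel, and the shapes are assumptions for illustration):

```python
import torch


def topk_with_padding_mask(
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    top_k: int,
    num_token_non_padded: int,
):
    # Ordinary top-k routing over all tokens, padded ones included.
    topk_weights, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    # Tokens at index >= num_token_non_padded are CUDA-graph padding:
    # zero their routing weights and invalidate their expert ids so a
    # dispatcher can skip them (the -1 sentinel is this sketch's choice,
    # not DeepEP's documented convention).
    is_pad = (
        torch.arange(router_logits.shape[0], device=router_logits.device)
        >= num_token_non_padded
    )
    topk_weights = topk_weights.masked_fill(is_pad.unsqueeze(-1), 0.0)
    topk_ids = topk_ids.masked_fill(is_pad.unsqueeze(-1), -1)
    return topk_weights, topk_ids
```

The payoff is in the dispatch step: expert GEMMs and all-to-all traffic scale with the number of routed tokens, so dropping the padded slots up front saves work proportional to the padding.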
@@ -165,7 +165,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
+    def forward(self, x, forward_batch: Optional[ForwardBatch] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -287,12 +287,12 @@ class DeepseekV2MoE(nn.Module):
         )
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not global_server_args_dict["enable_deepep_moe"]:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states, forward_mode)
+            return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
@@ -309,8 +309,9 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         shared_output = None
         if (
             forward_mode is not None
@@ -330,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
             )
         else:
             topk_idx = torch.full(
@@ -1339,7 +1341,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and (not self.info.is_sparse)
             and hidden_states.shape[0] == 0
         ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+            hidden_states = self.mlp(hidden_states, forward_batch)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
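A side effect of the diff worth noting: once callees receive the batch object, new per-batch fields (like num_token_non_padded here) no longer have to be threaded through every intermediate signature. A hedged illustration of that pattern, with `ForwardBatchSketch` as a stand-in for SGLang's real ForwardBatch:

```python
from dataclasses import dataclass
from enum import Enum, auto

import torch


class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()


@dataclass
class ForwardBatchSketch:
    # Stand-in for sglang's ForwardBatch; only the fields this sketch uses.
    forward_mode: ForwardMode
    num_token_non_padded: int  # real tokens before CUDA-graph padding


def forward_deepep(
    hidden_states: torch.Tensor, forward_batch: ForwardBatchSketch
) -> torch.Tensor:
    # Same derivation the diff adds at the top of forward_deepep.
    forward_mode = forward_batch.forward_mode
    assert forward_mode in (ForwardMode.EXTEND, ForwardMode.DECODE)
    # ...routing would consume forward_batch.num_token_non_padded here...
    return hidden_states
```

Passing the batch descriptor keeps DeepseekV2MLP and DeepseekV2MoE signature-compatible, which is why the dense-MLP call site in DeepseekV2DecoderLayer changes to `self.mlp(hidden_states, forward_batch)` in the last hunk.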