[perf]Support MOE Multi-stream in Deepseek (#947)
### What this PR does / why we need it? Support MOE inner Multi-stream for Deepseek. This feature requires graph mode with mc2 enabled. --------- Signed-off-by: David9857 <985700846@qq.com>
This commit is contained in:
@@ -39,19 +39,18 @@ VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
global_bs = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs = {
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage1_kwargs)
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
|
||||
# moeCombine
|
||||
kwargs = {
|
||||
kwargs_mc2 = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage3_kwargs)
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -675,7 +674,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
**kwargs)
|
||||
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -772,6 +772,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_shared_expert = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_shared_expert
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
@@ -818,7 +820,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
enable_force_load_balance: bool = False,
|
||||
top_k=None):
|
||||
top_k=None,
|
||||
**kwargs):
|
||||
assert self.quant_method is not None
|
||||
|
||||
if top_k:
|
||||
@@ -862,7 +865,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance)
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
**kwargs)
|
||||
|
||||
if self.enable_multistream_shared_expert and not is_prefill:
|
||||
hidden_states, shared_output = hidden_states
|
||||
|
||||
if self.dp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
@@ -886,4 +893,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
if self.enable_multistream_shared_expert and not is_prefill:
|
||||
return hidden_states, shared_output
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user