[perf] Support MoE multi-stream in DeepSeek (#947)
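### What this PR does / why we need it? Support multi-stream execution inside the MoE layer for DeepSeek: the shared-expert computation is issued on a separate NPU stream (via torchair's `npu_stream_switch`) so that it can overlap with the routed-expert dispatch and MLP path. This feature requires graph mode with mc2 enabled. --------- Signed-off-by: David9857 <985700846@qq.com>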
This commit is contained in:
@@ -20,7 +20,8 @@ from typing import Any, Callable, Dict, Optional
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.distributed import GroupCoordinator
|
||||
import torchair as tng # type: ignore
|
||||
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
group_list_type: int = 1,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
@@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
shared_experts = kwargs.get('shared_experts', None)
|
||||
if shared_experts:
|
||||
shared_gate_up = kwargs.get('shared_gate_up', None)
|
||||
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
|
||||
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=shared_gate_up,
|
||||
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
|
||||
activation_scale=shared_dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=None,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
|
||||
if shared_experts:
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_x, hidden_states)
|
||||
shared_output = torch_npu.npu_quant_matmul(
|
||||
shared_x,
|
||||
shared_experts.down_proj.weight,
|
||||
shared_experts.down_proj.weight_scale,
|
||||
pertoken_scale=shared_dynamic_scale,
|
||||
output_dtype=torch.bfloat16,
|
||||
)
|
||||
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
|
||||
shared_output = tensor_model_parallel_all_reduce(shared_output)
|
||||
if shared_experts:
|
||||
return hidden_states, shared_output
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: str = "",
|
||||
) -> torch.Tensor:
|
||||
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: str = "",
|
||||
**kwargs) -> torch.Tensor:
|
||||
global_bs = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
# hidden_states = hidden_states.bfloat16()
|
||||
kwargs = {
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage1_kwargs)
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
||||
shared_experts = kwargs.get('shared_experts', None)
|
||||
if shared_experts:
|
||||
shared_hidden_states = kwargs.get('shared_hidden_states', None)
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
|
||||
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
shared_hidden_states)
|
||||
shared_gate_up = torch_npu.npu_quant_matmul(
|
||||
shared_x,
|
||||
shared_experts.gate_up_proj.weight,
|
||||
shared_experts.gate_up_proj.weight_scale,
|
||||
output_dtype=torch.int32,
|
||||
)
|
||||
kwargs.update({
|
||||
"shared_gate_up": shared_gate_up,
|
||||
"shared_dynamic_scale": shared_dynamic_scale,
|
||||
})
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
@@ -166,10 +217,15 @@ def fused_experts_with_mc2(
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_token_nums,
|
||||
dynamic_scale=dynamic_scale)
|
||||
dynamic_scale=dynamic_scale,
|
||||
**kwargs)
|
||||
|
||||
multi_stream = isinstance(down_out_list, tuple)
|
||||
if multi_stream:
|
||||
down_out_list, shared_output = down_out_list
|
||||
|
||||
# moeCombine
|
||||
kwargs = {
|
||||
kwargs_mc2 = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
@@ -193,10 +249,12 @@ def fused_experts_with_mc2(
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage3_kwargs)
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
|
||||
if multi_stream:
|
||||
return hidden_states, shared_output
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -634,7 +692,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
**kwargs)
|
||||
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
|
||||
Reference in New Issue
Block a user