[feat][torchair] support super kernel feat for quantized dsr1 (#3485)
### What this PR does / why we need it?
Ports #1916 and #2157 to the master branch to fuse operators in the DeepSeek MoE layers into a single "super kernel", which reduces scheduling overhead on the device. Note that this feature is valid only when `tp_size = 1` and `multistream_overlap_shared_expert` is enabled together with torchair graph mode.

### Does this PR introduce _any_ user-facing change?
Users can enable this feature with `--additional-config '{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true}, "multistream_overlap_shared_expert":true}'`.

### How was this patch tested?
E2E DeepSeek serving in 2P1D disaggregated-prefill scenarios.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
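For illustration, the same flags can be passed in offline inference through vLLM's `additional_config` engine argument, which forwards platform-specific options to the Ascend plugin. This is a sketch only; the model id and parallel settings below are placeholders, not taken from this PR:

```python
# Illustrative sketch: enabling the super-kernel feature offline.
# The model path is a placeholder; the config keys are the ones
# documented in this PR's description.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",   # placeholder model id
    tensor_parallel_size=1,            # the feature requires tp_size = 1
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_super_kernel": True,
        },
        "multistream_overlap_shared_expert": True,
    },
)
```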
```diff
@@ -26,7 +26,8 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
                                is_enable_nz,
```
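The newly imported `super_kernel` helper is used below as a conditional fusion scope. As a rough sketch of the shape such a wrapper takes (assuming torchair exposes a named super-kernel scope; the real helper lives in `vllm_ascend/torchair/utils.py` and is not part of this diff):

```python
from contextlib import contextmanager

@contextmanager
def super_kernel(prefix: str, options: str, enabled: bool = True):
    """Sketch: fuse the ops issued inside this scope into one super
    kernel when enabled; otherwise behave as a transparent no-op."""
    if enabled:
        # Assumption: torchair provides a named fusion scope that takes
        # a scope name and compile options such as "stream-fusion=1".
        import torchair
        with torchair.scope.super_kernel(prefix, options):
            yield
    else:
        yield
```

Making the scope a no-op when disabled is what lets the same call sites serve both the fused and unfused paths without branching.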
```diff
@@ -927,6 +928,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
```
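The two new keyword arguments are expected to be threaded in by the MoE layer that owns this quant method. A plausible wiring, with assumed names (the actual call site is outside this diff), might look like:

```python
# Hypothetical caller-side wiring for the new arguments; names other
# than the config keys are illustrative, not taken from this diff.
from vllm_ascend.ascend_config import get_ascend_config

ascend_config = get_ascend_config()
torchair_cfg = ascend_config.torchair_graph_config
# Fuse only when torchair graph mode and the super-kernel knob are both on.
running_in_super_kernel = (torchair_cfg.enabled
                           and torchair_cfg.enable_super_kernel)
# `prefix` names the fusion scope so each MoE layer gets its own super kernel.
prefix = "model.layers.0.mlp.experts"  # placeholder scope name
```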
```diff
@@ -934,55 +937,59 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
 
         is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
 
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
         fused_moe_state = get_forward_context().fused_moe_state
         if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
             fused_moe_state = FusedMoEState.All2All
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too much tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if is_deepseek_v3_r1:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # topk currently is 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=
+                    1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False, # todo new api; should the third output be output
+                    # y2_flag=False, # old api; should the third output be output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = torchair_select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too much tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            topk_weights = topk_weights.to(x.dtype)
 
-        topk_weights = topk_weights.to(x.dtype)
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return torchair_fused_experts_with_allgather(
                 hidden_states=x,
```
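For reference, with the fixed DeepSeek-V3/R1 arguments above (`group_count=8`, `k_group=4`, `k=8`, sigmoid scoring plus a correction bias, and groups ranked by the sum of their top-2 scores), the routing that `npu_moe_gating_top_k` fuses should be equivalent to the grouped top-k below. This is an illustrative pure-PyTorch sketch of the algorithm (function name is ours), not the NPU kernel's implementation:

```python
import torch

def grouped_topk_reference(router_logits: torch.Tensor,
                           e_score_correction_bias: torch.Tensor,
                           num_expert_group: int = 8,
                           topk_group: int = 4,
                           top_k: int = 8,
                           eps: float = 1e-20):
    """Sigmoid grouped top-k routing (group_select_mode=1:
    rank each group by the sum of its top-2 biased scores)."""
    num_tokens, num_experts = router_logits.shape
    scores = router_logits.sigmoid()           # norm_type=1: sigmoid
    biased = scores + e_score_correction_bias  # bias affects selection only
    # Score each expert group by the sum of its top-2 biased scores.
    group_scores = (biased.view(num_tokens, num_expert_group, -1)
                    .topk(2, dim=-1).values.sum(dim=-1))
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    # Mask out experts whose group was not selected.
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = (group_mask.unsqueeze(-1)
                   .expand(num_tokens, num_expert_group,
                           num_experts // num_expert_group)
                   .reshape(num_tokens, num_experts))
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(top_k, dim=-1).indices
    # Weights come from the unbiased scores, renormalized over the top-k
    # (routed_scaling_factor=1, so no extra scaling).
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / (topk_weights.sum(-1, keepdim=True) + eps)
    return topk_weights, topk_ids
```

Calling this with `router_logits` of shape `(num_tokens, 256)` should reproduce the selection semantics of the fused kernel under these fixed arguments.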
```diff
@@ -995,25 +1002,28 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
-            return torchair_fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale,
-                dynamic_eplb=self.dynamic_eplb)
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                return torchair_fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    mc2_mask=kwargs.get("mc2_mask", None),
+                    shared_gate_up=shared_gate_up,
+                    shared_dequant_scale=shared_dequant_scale,
+                    dynamic_eplb=self.dynamic_eplb)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
```