[feat][torchair] support super kernel feat for quantized dsr1 (#3485)

### What this PR does / why we need it?
Port #1916 and #2157 to master branch to fuse operators in deepseek moe
layers, which can reduce scheduling overhead on devices. Note that this
feature is valid only when `tp_size = 1` and
`multistream_overlap_shared_expert` is enabled with torchair graph mode.

### Does this PR introduce _any_ user-facing change?
Users can enable this feature with `--additional-config
'{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true},
"multistream_overlap_shared_expert":true}'`.

### How was this patch tested?
E2E deepseek serving with 2P1D disaggregated prefill scenarios.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-10-20 20:04:37 +08:00
committed by GitHub
parent 70bef33f13
commit 068ed706c8
8 changed files with 138 additions and 86 deletions

View File

@@ -328,14 +328,22 @@ class TorchairDeepseekV2MoE(nn.Module):
ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled
self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
self.params_dtype = torch.float32 if self.enable_super_kernel else \
torch.get_default_dtype()
# Converting gate weight to fp32 is to adapt to the super kernel feature.
# Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add.
# In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem,
# modifications will be made in the initialization stage.
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
params_dtype=self.params_dtype,
prefix=f"{prefix}.gate")
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts))
torch.empty(config.n_routed_experts, dtype=self.params_dtype))
else:
self.gate.e_score_correction_bias = None

View File

@@ -48,7 +48,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
super_kernel)
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
@@ -990,6 +991,7 @@ class TorchairAscendFusedMoE(FusedMoE):
)
TorchairAscendFusedMoE.moe_counter += 1
self.moe_instance_id = TorchairAscendFusedMoE.moe_counter
self.prefix = prefix
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -1096,6 +1098,7 @@ class TorchairAscendFusedMoE(FusedMoE):
self.multistream_overlap_shared_expert = \
ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled
self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1192,16 +1195,24 @@ class TorchairAscendFusedMoE(FusedMoE):
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicFusedMoEMethod
running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
if self.multistream_overlap_shared_expert:
if not self.rm_router_logits:
router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \
isinstance(self.quant_method.quant_method,
TorchairAscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
with super_kernel(self.prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
if not self.rm_router_logits:
if self.enable_super_kernel:
router_logits, _ = gate(hidden_states.float())
else:
router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \
isinstance(self.quant_method.quant_method,
TorchairAscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
if shared_experts:
if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
@@ -1305,6 +1316,8 @@ class TorchairAscendFusedMoE(FusedMoE):
mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
prefix=self.prefix,
running_in_super_kernel=running_in_super_kernel,
)
if shared_experts:

View File

@@ -26,7 +26,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
super_kernel)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version,
is_enable_nz,
@@ -927,6 +928,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
prefix: str = "",
running_in_super_kernel: bool = False,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -934,55 +937,59 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
with super_kernel(prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
topk_weights = topk_weights.to(x.dtype)
if fused_moe_state == FusedMoEState.AllGatherEP:
return torchair_fused_experts_with_allgather(
hidden_states=x,
@@ -995,25 +1002,28 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
top_k=top_k,
expert_map=expert_map)
elif fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None),
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
dynamic_eplb=self.dynamic_eplb)
with super_kernel(prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None),
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
dynamic_eplb=self.dynamic_eplb)
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
import torch
import torch_npu
from torchair.scope import super_kernel as _super_kernel
try:
# Recent release of torchair has moved these ops to `.scope`.
@@ -231,3 +232,7 @@ def torchair_ops_patch():
AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign]
AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign]
def super_kernel(prefix: str, option: str, enabled: bool = True):
return _super_kernel(prefix, option) if enabled else nullcontext()