[feat][torchair] support super kernel feature for quantized DeepSeek-R1 (#3485)

### What this PR does / why we need it?
Ports #1916 and #2157 to the master branch to fuse operators in the DeepSeek MoE
layers into a super kernel, which can reduce scheduling overhead on Ascend devices.
Note that this feature takes effect only when `tp_size = 1` and
`multistream_overlap_shared_expert` is enabled together with torchair graph mode.
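
Conceptually, the change wraps the MoE gate and shared-expert quantization in a fused region that is only active for the MC2 fused-MoE state. Below is a minimal sketch of such a no-op-capable `super_kernel` helper in the spirit of the one imported from `vllm_ascend.torchair.utils`; the underlying torchair scope call and its signature are assumptions, not the exact implementation in this PR.

```python
# Illustrative sketch only: a no-op-capable super_kernel context manager.
# The torchair scope call used here is an assumption about the underlying API.
from contextlib import contextmanager


@contextmanager
def super_kernel(prefix: str, options: str, enabled: bool = True):
    """Fuse the wrapped ops into one super kernel when enabled; otherwise no-op."""
    if not enabled:
        yield
        return
    # Assumed torchair scope API (Ascend graph mode); imported lazily.
    from torchair.scope import super_kernel as tng_super_kernel
    with tng_super_kernel(prefix, options):
        yield
```

In the diff below, the caller computes `running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2` and passes it as `enabled`, so the wrapper degrades to a plain pass-through outside the MC2 path.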

### Does this PR introduce _any_ user-facing change?
Users can enable this feature with `--additional-config
'{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true},
"multistream_overlap_shared_expert":true}'`.

### How was this patch tested?
End-to-end DeepSeek serving in a 2P1D disaggregated prefill scenario.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Commit 068ed706c8 (parent 70bef33f13), committed by GitHub
Author: linfeng-yuan
Date: 2025-10-20 20:04:37 +08:00
8 changed files with 138 additions and 86 deletions


@@ -48,7 +48,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -990,6 +991,7 @@ class TorchairAscendFusedMoE(FusedMoE):
         )
         TorchairAscendFusedMoE.moe_counter += 1
         self.moe_instance_id = TorchairAscendFusedMoE.moe_counter
+        self.prefix = prefix
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1096,6 +1098,7 @@ class TorchairAscendFusedMoE(FusedMoE):
         self.multistream_overlap_shared_expert = \
             ascend_config.multistream_overlap_shared_expert and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1192,16 +1195,24 @@
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
             TorchairAscendW8A8DynamicFusedMoEMethod
+        running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
         if self.multistream_overlap_shared_expert:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           TorchairAscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                if not self.rm_router_logits:
+                    if self.enable_super_kernel:
+                        router_logits, _ = gate(hidden_states.float())
+                    else:
+                        router_logits, _ = gate(hidden_states)
+                if hasattr(self.quant_method, "quant_method") and \
+                    isinstance(self.quant_method.quant_method,
+                               TorchairAscendW8A8DynamicFusedMoEMethod
+                               ) and fused_moe_state == FusedMoEState.MC2:
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                            hidden_states)
         if shared_experts:
             if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
@@ -1305,6 +1316,8 @@
             mc2_mask=mc2_mask,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            prefix=self.prefix,
+            running_in_super_kernel=running_in_super_kernel,
         )
         if shared_experts: