[feat][torchair] support super kernel feat for quantized dsr1 (#3485)

### What this PR does / why we need it?
Port #1916 and #2157 to master branch to fuse operators in deepseek moe
layers, which can reduce scheduling overhead on devices. Note that this
feature is valid only when `tp_size = 1` and
`multistream_overlap_shared_expert` is enabled with torchair graph mode.

### Does this PR introduce _any_ user-facing change?
Users can enable this feature with `--additional-config
'{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true},
"multistream_overlap_shared_expert":true}'`.

### How was this patch tested?
E2E deepseek serving with 2P1D disaggregated prefill scenarios.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-10-20 20:04:37 +08:00
committed by GitHub
parent 70bef33f13
commit 068ed706c8
8 changed files with 138 additions and 86 deletions

View File

@@ -58,6 +58,7 @@ The details of each config option are as follows:
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.|
**ascend_scheduler_config**

View File

@@ -56,17 +56,18 @@ class TestAscendUnquantizedLinearMethod(TestBase):
def setUp(self):
self.method = AscendUnquantizedLinearMethod()
self.layer = mock.MagicMock()
mock_dtype = mock.PropertyMock(return_value=torch.float16)
type(self.layer.weight.data).dtype = mock_dtype
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_enable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(layer)
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@@ -74,23 +75,19 @@ class TestAscendUnquantizedLinearMethod(TestBase):
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_disable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 0
self.method.process_weights_after_loading(layer)
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_not_called()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch.version")
def test_process_weights_after_loading_not_8_3(self, mock_version,
mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.2.RC1"
mock_is_nz.return_value = 1
# Should not raise exception
self.method.process_weights_after_loading(layer)
self.method.process_weights_after_loading(self.layer)
class TestAscendRowParallelLinear(BaseLinearTest):

View File

@@ -37,7 +37,8 @@ class AscendConfig:
torchair_graph_config = additional_config.get("torchair_graph_config",
{})
self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
self.torchair_graph_config = TorchairGraphConfig(
torchair_graph_config, vllm_config, additional_config)
ascend_scheduler_config = additional_config.get(
"ascend_scheduler_config", {})
@@ -133,7 +134,7 @@ class TorchairGraphConfig:
Configuration Object for torchair_graph_config from additional_config
"""
def __init__(self, torchair_graph_config):
def __init__(self, torchair_graph_config, vllm_config, additional_config):
self.enabled = torchair_graph_config.get("enabled", False)
self.mode = torchair_graph_config.get("mode", '')
self.use_cached_graph = torchair_graph_config.get(
@@ -151,6 +152,8 @@ class TorchairGraphConfig:
self.enable_frozen_parameter = torchair_graph_config.get(
"enable_frozen_parameter", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
self.enable_super_kernel = torchair_graph_config.get(
"enable_super_kernel", False)
if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
@@ -186,6 +189,20 @@ class TorchairGraphConfig:
raise RuntimeError(
"enable_kv_nz is valid only when Torchair graph mode is enabled"
)
if self.enable_super_kernel:
raise RuntimeError(
"enable_super_kernel is valid only when Torchair graph mode is enabled"
)
if self.enable_super_kernel:
if vllm_config.parallel_config.tensor_parallel_size != 1:
raise RuntimeError(
"enable_super_kernel is valid only when tensor_parallel_size is 1"
)
if not additional_config.get("multistream_overlap_shared_expert",
False):
raise RuntimeError(
"enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
)
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
raise RuntimeError(
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"

View File

@@ -44,7 +44,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if is_enable_nz() and torch.version.cann.startswith("8.3"):
if (is_enable_nz() and torch.version.cann.startswith("8.3") and
layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

View File

@@ -328,14 +328,22 @@ class TorchairDeepseekV2MoE(nn.Module):
ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled
self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
self.params_dtype = torch.float32 if self.enable_super_kernel else \
torch.get_default_dtype()
# Converting gate weight to fp32 is to adapt to the super kernel feature.
# Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add.
# In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem,
# modifications will be made in the initialization stage.
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
params_dtype=self.params_dtype,
prefix=f"{prefix}.gate")
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts))
torch.empty(config.n_routed_experts, dtype=self.params_dtype))
else:
self.gate.e_score_correction_bias = None

View File

@@ -48,7 +48,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
super_kernel)
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
@@ -990,6 +991,7 @@ class TorchairAscendFusedMoE(FusedMoE):
)
TorchairAscendFusedMoE.moe_counter += 1
self.moe_instance_id = TorchairAscendFusedMoE.moe_counter
self.prefix = prefix
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -1096,6 +1098,7 @@ class TorchairAscendFusedMoE(FusedMoE):
self.multistream_overlap_shared_expert = \
ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled
self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1192,16 +1195,24 @@ class TorchairAscendFusedMoE(FusedMoE):
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicFusedMoEMethod
running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
if self.multistream_overlap_shared_expert:
if not self.rm_router_logits:
router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \
isinstance(self.quant_method.quant_method,
TorchairAscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
with super_kernel(self.prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
if not self.rm_router_logits:
if self.enable_super_kernel:
router_logits, _ = gate(hidden_states.float())
else:
router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \
isinstance(self.quant_method.quant_method,
TorchairAscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
if shared_experts:
if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
@@ -1305,6 +1316,8 @@ class TorchairAscendFusedMoE(FusedMoE):
mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
prefix=self.prefix,
running_in_super_kernel=running_in_super_kernel,
)
if shared_experts:

View File

@@ -26,7 +26,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
super_kernel)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version,
is_enable_nz,
@@ -927,6 +928,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
prefix: str = "",
running_in_super_kernel: bool = False,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -934,55 +937,59 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
with super_kernel(prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
topk_weights = topk_weights.to(x.dtype)
if fused_moe_state == FusedMoEState.AllGatherEP:
return torchair_fused_experts_with_allgather(
hidden_states=x,
@@ -995,25 +1002,28 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
top_k=top_k,
expert_map=expert_map)
elif fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None),
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
dynamic_eplb=self.dynamic_eplb)
with super_kernel(prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None),
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
dynamic_eplb=self.dynamic_eplb)
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
import torch
import torch_npu
from torchair.scope import super_kernel as _super_kernel
try:
# Recent release of torchair has moved these ops to `.scope`.
@@ -231,3 +232,7 @@ def torchair_ops_patch():
AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign]
AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign]
def super_kernel(prefix: str, option: str, enabled: bool = True):
return _super_kernel(prefix, option) if enabled else nullcontext()