diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index fa2e4fc..7df2c48 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -58,6 +58,7 @@ The details of each config option are as follows: | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | | `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | +| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effect on moe models using dynamic w8a8 quantization. | **ascend_scheduler_config** diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index e2b0eff..1153bfe 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -56,17 +56,18 @@ class TestAscendUnquantizedLinearMethod(TestBase): def setUp(self): self.method = AscendUnquantizedLinearMethod() + self.layer = mock.MagicMock() + mock_dtype = mock.PropertyMock(return_value=torch.float16) + type(self.layer.weight.data).dtype = mock_dtype @mock.patch("vllm_ascend.ops.linear.is_enable_nz") @mock.patch("torch_npu.npu_format_cast") @mock.patch("torch.version") def test_process_weights_after_loading_is_8_3_enable_nz( self, mock_version, mock_format_cast, mock_is_nz): - layer = mock.MagicMock() - mock_version.cann = "8.3.RC1" mock_is_nz.return_value = 1 - self.method.process_weights_after_loading(layer) + self.method.process_weights_after_loading(self.layer) mock_format_cast.assert_called_once() @mock.patch("vllm_ascend.ops.linear.is_enable_nz") @@ -74,23 +75,19 @@ class TestAscendUnquantizedLinearMethod(TestBase): @mock.patch("torch.version") def 
test_process_weights_after_loading_is_8_3_disable_nz( self, mock_version, mock_format_cast, mock_is_nz): - layer = mock.MagicMock() - mock_version.cann = "8.3.RC1" mock_is_nz.return_value = 0 - self.method.process_weights_after_loading(layer) + self.method.process_weights_after_loading(self.layer) mock_format_cast.assert_not_called() @mock.patch("vllm_ascend.ops.linear.is_enable_nz") @mock.patch("torch.version") def test_process_weights_after_loading_not_8_3(self, mock_version, mock_is_nz): - layer = mock.MagicMock() - mock_version.cann = "8.2.RC1" mock_is_nz.return_value = 1 # Should not raise exception - self.method.process_weights_after_loading(layer) + self.method.process_weights_after_loading(self.layer) class TestAscendRowParallelLinear(BaseLinearTest): diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index a265e96..9f43b2d 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -37,7 +37,8 @@ class AscendConfig: torchair_graph_config = additional_config.get("torchair_graph_config", {}) - self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config) + self.torchair_graph_config = TorchairGraphConfig( + torchair_graph_config, vllm_config, additional_config) ascend_scheduler_config = additional_config.get( "ascend_scheduler_config", {}) @@ -133,7 +134,7 @@ class TorchairGraphConfig: Configuration Object for torchair_graph_config from additional_config """ - def __init__(self, torchair_graph_config): + def __init__(self, torchair_graph_config, vllm_config, additional_config): self.enabled = torchair_graph_config.get("enabled", False) self.mode = torchair_graph_config.get("mode", '') self.use_cached_graph = torchair_graph_config.get( @@ -151,6 +152,8 @@ class TorchairGraphConfig: self.enable_frozen_parameter = torchair_graph_config.get( "enable_frozen_parameter", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) + self.enable_super_kernel = torchair_graph_config.get( + 
"enable_super_kernel", False) if not isinstance(self.graph_batch_sizes, list): raise TypeError("graph_batch_sizes must be list[int]") @@ -186,6 +189,20 @@ class TorchairGraphConfig: raise RuntimeError( "enable_kv_nz is valid only when Torchair graph mode is enabled" ) + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode is enabled" + ) + if self.enable_super_kernel: + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise RuntimeError( + "enable_super_kernel is valid only when tensor_parallel_size is 1" + ) + if not additional_config.get("multistream_overlap_shared_expert", + False): + raise RuntimeError( + "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled" + ) if self.use_cached_kv_cache_bytes and not self.use_cached_graph: raise RuntimeError( "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled" diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 665ac74..81d7d9e 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -44,7 +44,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - if is_enable_nz() and torch.version.cann.startswith("8.3"): + if (is_enable_nz() and torch.version.cann.startswith("8.3") and + layer.weight.data.dtype in [torch.float16, torch.bfloat16]): layer.weight.data = torch_npu.npu_format_cast( layer.weight.data, ACL_FORMAT_FRACTAL_NZ) diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 7f09c52..8257a09 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -328,14 +328,22 @@ class TorchairDeepseekV2MoE(nn.Module): ascend_config.multistream_overlap_shared_expert and \ 
self.torchair_graph_enabled + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel + self.params_dtype = torch.float32 if self.enable_super_kernel else \ + torch.get_default_dtype() + # Converting gate weight to fp32 is to adapt to the super kernel feature. + # Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add. + # In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem, + # modifications will be made in the initialization stage. self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, bias=False, quant_config=None, + params_dtype=self.params_dtype, prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=self.params_dtype)) else: self.gate.e_score_correction_bias = None diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 72f8cb7..9a07e8c 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -48,7 +48,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor, + super_kernel) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, get_ascend_soc_version, @@ -990,6 +991,7 @@ class TorchairAscendFusedMoE(FusedMoE): ) TorchairAscendFusedMoE.moe_counter += 1 self.moe_instance_id = TorchairAscendFusedMoE.moe_counter + self.prefix = prefix if params_dtype is None: params_dtype 
= torch.get_default_dtype() @@ -1096,6 +1098,7 @@ class TorchairAscendFusedMoE(FusedMoE): self.multistream_overlap_shared_expert = \ ascend_config.multistream_overlap_shared_expert and \ self.torchair_graph_enabled + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp if self.scoring_func != "softmax" and not self.use_grouped_topk: @@ -1192,16 +1195,24 @@ class TorchairAscendFusedMoE(FusedMoE): quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ TorchairAscendW8A8DynamicFusedMoEMethod + running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2 + if self.multistream_overlap_shared_expert: - if not self.rm_router_logits: - router_logits, _ = gate(hidden_states) - if hasattr(self.quant_method, "quant_method") and \ - isinstance(self.quant_method.quant_method, - TorchairAscendW8A8DynamicFusedMoEMethod - ) and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( - hidden_states) + with super_kernel(self.prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + if not self.rm_router_logits: + if self.enable_super_kernel: + router_logits, _ = gate(hidden_states.float()) + else: + router_logits, _ = gate(hidden_states) + if hasattr(self.quant_method, "quant_method") and \ + isinstance(self.quant_method.quant_method, + TorchairAscendW8A8DynamicFusedMoEMethod + ) and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( + hidden_states) if shared_experts: if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2: @@ -1305,6 +1316,8 @@ class TorchairAscendFusedMoE(FusedMoE): mc2_mask=mc2_mask, 
quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, + prefix=self.prefix, + running_in_super_kernel=running_in_super_kernel, ) if shared_experts: diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 5bd622e..23d59a8 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -26,7 +26,8 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor, + super_kernel) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, dispose_tensor, get_ascend_soc_version, is_enable_nz, @@ -927,6 +928,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + prefix: str = "", + running_in_super_kernel: bool = False, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -934,55 +937,59 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: 
sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - fused_moe_state = get_forward_context().fused_moe_state if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: fused_moe_state = FusedMoEState.All2All shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, router_logits) - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. 
- if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + with super_kernel(prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if is_deepseek_v3_r1: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = torchair_select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. 
+ if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + topk_weights = topk_weights.to(x.dtype) - topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.AllGatherEP: return torchair_fused_experts_with_allgather( hidden_states=x, @@ -995,25 +1002,28 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map) elif fused_moe_state == FusedMoEState.MC2: - return torchair_fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - mc2_mask=kwargs.get("mc2_mask", None), - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - dynamic_eplb=self.dynamic_eplb) + with super_kernel(prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + return torchair_fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_fp32, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + mc2_mask=kwargs.get("mc2_mask", None), + shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + dynamic_eplb=self.dynamic_eplb) elif fused_moe_state in [ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast ]: diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 164a620..211d738 100644 --- 
a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import torch import torch_npu +from torchair.scope import super_kernel as _super_kernel try: # Recent release of torchair has moved these ops to `.scope`. @@ -231,3 +232,7 @@ def torchair_ops_patch(): AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign] + + +def super_kernel(prefix: str, option: str, enabled: bool = True): + return _super_kernel(prefix, option) if enabled else nullcontext()