From 5d8ec280090b4a7567fb2b50a7cedda44902c37f Mon Sep 17 00:00:00 2001
From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com>
Date: Tue, 26 Aug 2025 14:12:43 +0800
Subject: [PATCH] [2/N][refactor] split torchair from fused_moe (#2503)

### What this PR does / why we need it?
After moving the torchair-related fused_moe section into torchair_fused_moe,
split the torchair code out of the original fused_moe.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
vLLM version: main
vLLM main:
https://github.com/vllm-project/vllm/commit/ab9f2cfd1942f7ddfee658ce86ea96b4789862af

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/2a97ffc33de097f267f217132ced42f4714b7de5

Signed-off-by: hust17yixuan <303660421@qq.com>
---
 vllm_ascend/ops/fused_moe.py | 63 +++++++++---------------------------
 1 file changed, 16 insertions(+), 47 deletions(-)

diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 0d6dc9c..611935c 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -50,7 +50,6 @@ from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -76,8 +75,6 @@ def unified_fused_experts(
     w1_scale_bias: torch.Tensor = None,
     w2_scale_bias: torch.Tensor = None,
     moe_comm_method: Optional[MoECommMethod] = None,
-    # For TorchAir graph
-    is_torchair: bool = False,
     # For Cube/Vector parallel
     shared_experts: Optional[Any] = None,
     quantized_x_for_share: Optional[Any] = None,
@@ -191,7 +188,6 @@ def fused_experts_with_mc2(
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
-    is_torchair: bool = False,
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
@@ -199,8 +195,7 @@
     ep_world_size = moe_parallel_config.ep_size

     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
-                       or is_torchair)
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)

     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
@@ -246,11 +241,8 @@ fused_experts_with_mc2(
         0:5]

     if shared_experts is not None:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+        shared_act = shared_experts.act_fn(shared_gate_up)

     w1 = w1.transpose(1, 2)
@@ -324,9 +316,7 @@ fused_experts_with_mc2(
     if shared_experts is None:
         return hidden_states
     else:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act, down_out_list)
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
         return hidden_states, shared_hidden_states
@@ -930,9 +920,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.max_model_len = vllm_config.model_config.max_model_len
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        get_ascend_config()

         try:
             device_group = get_mc2_group().device_group
@@ -1169,10 +1157,6 @@ class AscendFusedMoE(FusedMoE):
             self.ep_size,
             get_ep_group().rank_in_group, self.global_num_experts)

-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
-            self.torchair_graph_enabled
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1278,23 +1262,10 @@ class AscendFusedMoE(FusedMoE):
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
-        if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)

         if shared_experts:
-            if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
-                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-                shared_hidden_states = shared_experts(hidden_states)
+            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+            shared_hidden_states = shared_experts(hidden_states)

         mc2_mask = forward_context.mc2_mask
@@ -1339,16 +1310,15 @@ class AscendFusedMoE(FusedMoE):
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
-                if not self.torchair_graph_enabled:
-                    max_tokens_across_dp = forward_context.max_tokens_across_dp
-                    if num_tokens < max_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
+                max_tokens_across_dp = forward_context.max_tokens_across_dp
+                if num_tokens < max_tokens_across_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states,
+                        (0, 0, 0, max_tokens_across_dp - num_tokens))
+                    if not self.rm_router_logits:
+                        router_logits = nn.functional.pad(
+                            router_logits,
                             (0, 0, 0, max_tokens_across_dp - num_tokens))
-                        if not self.rm_router_logits:
-                            router_logits = nn.functional.pad(
-                                router_logits,
-                                (0, 0, 0, max_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
                 if self.rm_router_logits:
                     router_logits, _ = gate(hidden_states)
@@ -1385,8 +1355,7 @@ class AscendFusedMoE(FusedMoE):
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            shared_experts=shared_experts if self.torchair_graph_enabled
-            and self.enable_multistream_moe and not is_prefill else None,
+            shared_experts=None,
             mc2_mask=mc2_mask,
             token_dispatcher=self.token_dispatcher,
             quantized_x_for_share=quantized_x_for_share,