From 79a910ef4730d3f1be14496a1681eee2566f64a0 Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Thu, 18 Sep 2025 17:35:04 +0800
Subject: [PATCH] [bugfix][torchair] fix multistream_moe problems in torchair graph mode (#2681)

This PR fixes two problems when `multistream_moe` is enabled in torchair
graph mode:
1. check `TorchairAscendW8A8DynamicFusedMoEMethod` instead of the incorrect
   `AscendW8A8DynamicFusedMoEMethod`
2. chunk `mc2_mask` regardless of whether `replace_allreduce` is True or
   False in the forward function of `TorchairAscendFusedMoE`

- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/0fb2551c238c7ccbcf6f25ef4646ce6c92f684d1

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 .../torchair/ops/torchair_fused_moe.py | 50 ++++++++++---------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py
index 2221130..6350fbb 100644
--- a/vllm_ascend/torchair/ops/torchair_fused_moe.py
+++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -1146,14 +1146,14 @@ class TorchairAscendFusedMoE(FusedMoE):
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
+        from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
+            TorchairAscendW8A8DynamicFusedMoEMethod
         if self.enable_multistream_moe:
             if not self.rm_router_logits:
                 router_logits, _ = gate(hidden_states)
             if hasattr(self.quant_method, "quant_method") and \
                 isinstance(self.quant_method.quant_method,
-                          AscendW8A8DynamicFusedMoEMethod
+                          TorchairAscendW8A8DynamicFusedMoEMethod
                 ) and fused_moe_state == FusedMoEState.MC2:
                 with npu_stream_switch("moe_secondary", 0):
                     quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
@@ -1178,31 +1178,33 @@ class TorchairAscendFusedMoE(FusedMoE):
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
-        ] and not replace_allreduce):
-            if fused_moe_state in {FusedMoEState.MC2}:
-                padding_size = forward_context.padded_num_tokens
-            else:
-                # TODO: Determine if we can remove the padding
-                padding_size = tp_size
-            if num_tokens < padding_size and not self.enable_shared_expert_dp:
-                hidden_states = nn.functional.pad(
-                    hidden_states, (0, 0, 0, padding_size - num_tokens))
-                router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, padding_size - num_tokens))
+        ]):
             if tp_size > 1:
                 tp_rank = get_tensor_model_parallel_rank()
-                if not self.enable_shared_expert_dp:
-                    chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                             tp_size,
-                                                             dim=0)
-                    chunk_router_logits = torch.tensor_split(router_logits,
-                                                             tp_size,
-                                                             dim=0)
-                    hidden_states = chunk_hidden_states[tp_rank]
-                    router_logits = chunk_router_logits[tp_rank]
                 chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]
+            if not replace_allreduce:
+                if fused_moe_state in {FusedMoEState.MC2}:
+                    padding_size = forward_context.padded_num_tokens
+                else:
+                    # TODO: Determine if we can remove the padding
+                    padding_size = tp_size
+                if num_tokens < padding_size and not self.enable_shared_expert_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, padding_size - num_tokens))
+                    router_logits = nn.functional.pad(
+                        router_logits, (0, 0, 0, padding_size - num_tokens))
+                if tp_size > 1:
+                    tp_rank = get_tensor_model_parallel_rank()
+                    if not self.enable_shared_expert_dp:
+                        chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                                 tp_size,
+                                                                 dim=0)
+                        chunk_router_logits = torch.tensor_split(router_logits,
+                                                                 tp_size,
+                                                                 dim=0)
+                        hidden_states = chunk_hidden_states[tp_rank]
+                        router_logits = chunk_router_logits[tp_rank]
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather: