From 17259cb265da85cd5f14ac87887b4a17047be61e Mon Sep 17 00:00:00 2001 From: weichen Date: Thu, 13 Nov 2025 09:38:11 +0800 Subject: [PATCH] [Perf] [MoE] optimize all2allv (#3738) ### What this PR does / why we need it? 1. Replace init_routing_v2 with token_permute to optimize performance. Note: This pr will be merged after switching ci to CANN 8.3 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vllm bench serve bs = 48 / rr = 10000 / 2k input -> 20k output: before: image after: image - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac Signed-off-by: Pr0Wh1teGivee --- vllm_ascend/ops/fused_moe/token_dispatcher.py | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index c6bdfe4d..1ef06533 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -701,33 +701,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): if self.with_quant: assert global_input_tokens_local_experts_indices is not None, \ "global_input_tokens_local_experts_indices must be provided" - expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze( + dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute( + dynamic_scale_after_all2all.unsqueeze(-1), + global_input_tokens_local_experts_indices) + dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze( -1) - active_num = global_input_tokens_local_experts_indices.numel() - - if active_num <= 0: - reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices - return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping - - global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( - global_input_tokens, - expert_idx_2d, - scale=dynamic_scale_after_all2all, - active_num=active_num, - expert_capacity=0, - expert_num=self.num_local_experts, - expert_tokens_num_type=1, - expert_tokens_num_flag=True, - active_expert_range=[0, self.num_local_experts], - quant_mode=-1, - row_idx_type=0, - ) - return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping # Non-quantized case global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( global_input_tokens, global_input_tokens_local_experts_indices) - return global_input_tokens, None, reversed_global_input_permutation_mapping + return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor: