[Perf] [MoE] optimize all2allv (#3738)

### What this PR does / why we need it?
1. Replace `npu_moe_init_routing_v2` with `npu_moe_token_permute` in `TokenDispatcherWithAll2AllV` to optimize performance.

Note: This PR will be merged after CI is switched to CANN 8.3.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
vllm bench serve bs = 48 / rr = 10000 / 2k input -> 20k output:
before:
<img width="489" height="488" alt="image"
src="https://github.com/user-attachments/assets/268a19e6-9ab2-47f0-84a1-4f6d3bc342e2"
/>
after:
<img width="480" height="500" alt="image"
src="https://github.com/user-attachments/assets/d9b1e628-0520-42d5-8a21-b42f7cd7abc7"
/>
- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-11-13 09:38:11 +08:00
committed by GitHub
parent 6bc770cd78
commit 17259cb265

View File

@@ -701,33 +701,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
if self.with_quant:
assert global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be provided"
expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
dynamic_scale_after_all2all.unsqueeze(-1),
global_input_tokens_local_experts_indices)
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(
-1)
active_num = global_input_tokens_local_experts_indices.numel()
if active_num <= 0:
reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens,
expert_idx_2d,
scale=dynamic_scale_after_all2all,
active_num=active_num,
expert_capacity=0,
expert_num=self.num_local_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[0, self.num_local_experts],
quant_mode=-1,
row_idx_type=0,
)
return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
# Non-quantized case
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens, global_input_tokens_local_experts_indices)
return global_input_tokens, None, reversed_global_input_permutation_mapping
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states: torch.Tensor,
context_metadata: dict) -> torch.Tensor: