[Perf] [MoE] optimize all2allv (#3738)
### What this PR does / why we need it?
1. Replace `npu_moe_init_routing_v2` with `npu_moe_token_permute` in the all2allv token dispatcher to optimize performance.
Note: This PR will be merged after CI is switched to CANN 8.3.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Benchmarked with `vllm bench serve` (bs = 48 / rr = 10000 / 2k input -> 20k output):
before:
<img width="489" height="488" alt="image"
src="https://github.com/user-attachments/assets/268a19e6-9ab2-47f0-84a1-4f6d3bc342e2"
/>
after:
<img width="480" height="500" alt="image"
src="https://github.com/user-attachments/assets/d9b1e628-0520-42d5-8a21-b42f7cd7abc7"
/>
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -701,33 +701,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
if self.with_quant:
|
||||
assert global_input_tokens_local_experts_indices is not None, \
|
||||
"global_input_tokens_local_experts_indices must be provided"
|
||||
expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
|
||||
dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
|
||||
dynamic_scale_after_all2all.unsqueeze(-1),
|
||||
global_input_tokens_local_experts_indices)
|
||||
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(
|
||||
-1)
|
||||
active_num = global_input_tokens_local_experts_indices.numel()
|
||||
|
||||
if active_num <= 0:
|
||||
reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
|
||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||
|
||||
global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
global_input_tokens,
|
||||
expert_idx_2d,
|
||||
scale=dynamic_scale_after_all2all,
|
||||
active_num=active_num,
|
||||
expert_capacity=0,
|
||||
expert_num=self.num_local_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[0, self.num_local_experts],
|
||||
quant_mode=-1,
|
||||
row_idx_type=0,
|
||||
)
|
||||
return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
|
||||
|
||||
# Non-quantized case
|
||||
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens, global_input_tokens_local_experts_indices)
|
||||
return global_input_tokens, None, reversed_global_input_permutation_mapping
|
||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||
|
||||
def _combine_preprocess(self, hidden_states: torch.Tensor,
|
||||
context_metadata: dict) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user