[main][Prefill Perf] Optimize Quantized MoE Performance by Reducing All2All Communication (#2195)

This PR significantly improves the performance of quantized Mixture of
Experts (MoE) layers by reordering the quantization and communication
operations.

In the previous implementation, the `all2all` dispatch operated on
unquantized `hidden_states` (in FP16/BF16) *before* quantization,
making communication a major cost. By performing dynamic quantization
on each EP rank **first** and then exchanging the much smaller INT8
data (plus one per-token scale), we cut the communication volume by
nearly 50%.
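
For intuition, here is a rough, CPU-only sketch of the payload math. The hidden size is an illustrative assumption (not a value from this PR), and plain per-token symmetric INT8 quantization stands in for `torch_npu.npu_dynamic_quant`:

```python
import torch

hidden_size = 7168                      # illustrative hidden dimension (assumption)
tokens = torch.randn(4, hidden_size)    # a few routed tokens on one EP rank

# Per-token symmetric INT8 quantization, standing in for torch_npu.npu_dynamic_quant.
scales = tokens.abs().amax(dim=-1, keepdim=True) / 127.0
q_tokens = (tokens / scales).round().clamp(-128, 127).to(torch.int8)

# Old path sends 2 bytes/element (BF16); new path sends 1 byte/element
# plus one 4-byte FP32 scale per token.
bf16_bytes = tokens.numel() * 2
int8_bytes = q_tokens.numel() * 1 + scales.numel() * 4

print(f"BF16 payload: {bf16_bytes} B, INT8 payload: {int8_bytes} B "
      f"({int8_bytes / bf16_bytes:.1%} of BF16)")   # ~50.0%
```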

Additionally, this PR includes a minor optimization that casts the `int`
input of the `argsort` operation to `float`, so the sort runs on the
faster AI Core instead of falling back to the AICPU.
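
The trick appears in the diff below; as a standalone sketch (the tensor here is a made-up stand-in for the real `inverse_indices`):

```python
import torch

# Made-up permutation, standing in for the inverse_indices produced by re-routing.
inverse_indices = torch.randperm(4096).to(torch.int32)

# On Ascend, integer argsort may be dispatched to the AICPU; sorting the values
# as float32 keeps the kernel on the AI Core. Values this small are exactly
# representable in float32, so the resulting order is identical.
order = inverse_indices.to(torch.float32).argsort().to(torch.int32)
```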

Together, these changes deliver a significant prefill performance gain
in quantized MoE scenarios.
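
As a reading aid for the diff, the dataflow change in the quantized (`expert_map`) branch can be summarized roughly as follows; the arrows are schematic, and the real calls are the `torch_npu` ops shown below:

```python
# Before: dispatch full-precision tokens, quantize after communication.
#   npu_moe_init_routing (BF16) -> all_to_all (BF16 tokens, 2 B/elem)
#   -> sort by local expert -> dynamic quant inside apply_mlp -> grouped MLP
#
# After: quantize while routing, dispatch INT8 tokens plus per-token scales.
#   npu_moe_init_routing_quant (BF16 -> INT8 + FP32 scales)
#   -> all_to_all_single (INT8 tokens, ~1 B/elem) + all_to_all_single (scales)
#   -> npu_moe_re_routing -> grouped MLP with the precomputed dynamic_scale
```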

- vLLM version: v0.10.0
- vLLM main: 7175817637

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Slightwind authored on 2025-08-05 18:47:13 +08:00; committed by GitHub
parent 292fb8f696, commit f3b50c54e8
2 changed files with 151 additions and 43 deletions


```diff
@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output
 
 
+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
     device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
         local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
 
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
 
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
 
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
 
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
 
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 0
+        dynamic_scale = None
 
     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type)
 
     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)
 
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
```