[main][Prefill Perf] Optimize Quantized MoE Performance by Reducing All2All Communication (#2195)

This PR significantly optimizes performance for quantized Mixture of
Experts (MoE) layers by changing the order of quantization and
communication operations.

In the previous implementation, the `all2all` operation was performed on
unquantized `hidden_states` (in FP16/BF16) *before* quantization,
resulting in substantial communication overhead. By performing
quantization on each EP rank **first** and then sending the much smaller
quantized data, we reduce the communication volume by nearly 50%.

Additionally, this PR includes a minor optimization to cast `int` inputs
to `float` for the `argsort` operation, forcing it to run on a faster
NPU core instead of the AICPU.

These changes lead to a clear and significant performance gain in MoE
quantization scenarios.

- vLLM version: v0.10.0
- vLLM main:
7175817637

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Slightwind
2025-08-05 18:47:13 +08:00
committed by GitHub
parent 292fb8f696
commit f3b50c54e8
2 changed files with 151 additions and 43 deletions

View File

@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
result = fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))

View File

@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
return hidden_states, shared_output
def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
num_tokens, _ = hidden_states.shape
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
1, 0).contiguous().view(-1))
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
global_expert_tokens = global_expert_tokens.to(torch.int32)
quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device
if expert_map is not None:
global_num_experts = len(expert_map) + global_redundant_expert_num
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
hidden_states,
expert_idx=topk_ids.to(torch.int32),
active_num=0,
expert_capacity=0,
expert_num=global_num_experts,
drop_pad_mode=0,
expert_tokens_num_mode=2,
expert_tokens_before_capacity_flag=False,
quant_mode=1,
)
else:
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
hidden_states, top_k, topk_ids, global_num_experts)
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)
gather_sizes = global_expert_tokens.new_empty(
global_expert_tokens.shape[0])
dist.all_to_all_single(gather_sizes, global_expert_tokens)
gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()
token_counts_combined = torch.stack(
[gather_sizes, global_expert_tokens], dim=0)
token_counts_combined = token_counts_combined.view(
2, ep_group.world_size, -1).sum(dim=2)
token_counts_combined_cpu = token_counts_combined.to(
torch.device("cpu"), non_blocking=True).numpy()
all_tokens = gather_sizes.sum()
expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)
gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
quantized_tokens.shape[1])
dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
gather_size_list = token_counts_combined_cpu[1]
scatter_size_list = token_counts_combined_cpu[0]
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
dist.all_to_all_single(gathered_tokens, quantized_tokens,
scatter_size_list, gather_size_list)
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
gather_size_list)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0
hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
gathered_tokens,
gather_sizes.view(ep_group.world_size, -1),
per_token_scales=dynamic_scale)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
dynamic_scale = None
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
w2,
w2_scale,
expert_tokens, #16
dynamic_scale=dynamic_scale,
group_list_type=group_list_type)
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)
reordered_outputs = torch.index_select(
hidden_states,
dim=0,
# Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
dist.all_to_all_single(hidden_states, reordered_outputs,
gather_size_list, scatter_size_list)
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
export_for_source_row=None,
drop_pad_mode=2)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.