[main][Prefill Perf] Optimize Quantized MoE Performance by Reducing All2All Communication (#2195)
This PR significantly optimizes performance for quantized Mixture of
Experts (MoE) layers by changing the order of quantization and
communication operations.
In the previous implementation, the `all2all` operation was performed on
unquantized `hidden_states` (in FP16/BF16) *before* quantization,
resulting in substantial communication overhead. By performing
quantization on each EP rank **first** and then sending the much smaller
quantized data, we reduce the communication volume by nearly 50%.
Additionally, this PR includes a minor optimization to cast `int` inputs
to `float` for the `argsort` operation, forcing it to run on a faster
NPU core instead of the AICPU.
These changes lead to a clear and significant performance gain in MoE
quantization scenarios.
- vLLM version: v0.10.0
- vLLM main:
7175817637
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
75
tests/ut/quantization/test_w8a8_dynamic.py
Normal file
75
tests/ut/quantization/test_w8a8_dynamic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.hidden_size = 128
|
||||
self.num_tokens = 128
|
||||
self.placeholder = torch.randn(self.num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
@patch("torch.distributed.all_to_all_single")
|
||||
@patch("torch_npu.npu_moe_re_routing")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
@patch("torch_npu.npu_moe_finalize_routing")
|
||||
@patch("torch_npu.npu_moe_init_routing")
|
||||
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
|
||||
mock_moe_finalize_routing,
|
||||
mock_dynamic_quant, mock_swiglu,
|
||||
mock_grouped_matmul,
|
||||
mock_moe_re_routing,
|
||||
mock_all_to_all_single):
|
||||
expert_map = MagicMock()
|
||||
ep_group = MagicMock()
|
||||
placeholder_int8 = torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, self.hidden_size),
|
||||
dtype=torch.int8)
|
||||
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_moe_init_routing.return_value = (
|
||||
placeholder_int8,
|
||||
placeholder_ones,
|
||||
placeholder_ones,
|
||||
)
|
||||
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
||||
torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, ),
|
||||
dtype=torch.int32),
|
||||
self.placeholder)
|
||||
mock_grouped_matmul.return_value = self.placeholder
|
||||
mock_swiglu.return_value = self.placeholder
|
||||
mock_dynamic_quant.return_value = (
|
||||
placeholder_int8,
|
||||
torch.randn(self.num_tokens),
|
||||
)
|
||||
mock_moe_finalize_routing.return_value = self.placeholder
|
||||
|
||||
result = fused_experts_with_all2all(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=8,
|
||||
expert_map=expert_map,
|
||||
ep_group=ep_group,
|
||||
log2phy=None,
|
||||
global_redundant_expert_num=256,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.shape, (128, 128))
|
||||
@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_output
|
||||
|
||||
|
||||
def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
|
||||
num_tokens, _ = hidden_states.shape
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
|
||||
1, 0).contiguous().view(-1))
|
||||
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
||||
minlength=global_num_experts)
|
||||
global_expert_tokens = global_expert_tokens.to(torch.int32)
|
||||
quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_experts = w1.shape[0]
|
||||
device = hidden_states.device
|
||||
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(top_k, -1).permute(
|
||||
1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||
hidden_states,
|
||||
expert_idx=topk_ids.to(torch.int32),
|
||||
active_num=0,
|
||||
expert_capacity=0,
|
||||
expert_num=global_num_experts,
|
||||
drop_pad_mode=0,
|
||||
expert_tokens_num_mode=2,
|
||||
expert_tokens_before_capacity_flag=False,
|
||||
quant_mode=1,
|
||||
)
|
||||
else:
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
|
||||
hidden_states, top_k, topk_ids, global_num_experts)
|
||||
|
||||
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
||||
minlength=global_num_experts)
|
||||
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
||||
-1).sum(-1)
|
||||
gather_sizes = global_expert_tokens.new_empty(
|
||||
global_expert_tokens.shape[0])
|
||||
dist.all_to_all_single(gather_sizes, global_expert_tokens)
|
||||
|
||||
gather_sizes = torch.empty_like(scatter_sizes)
|
||||
dist.all_to_all_single(gather_sizes,
|
||||
scatter_sizes,
|
||||
group=ep_group.device_group)
|
||||
scatter_size_list = scatter_sizes.cpu().tolist()
|
||||
gather_size_list = gather_sizes.cpu().tolist()
|
||||
token_counts_combined = torch.stack(
|
||||
[gather_sizes, global_expert_tokens], dim=0)
|
||||
token_counts_combined = token_counts_combined.view(
|
||||
2, ep_group.world_size, -1).sum(dim=2)
|
||||
token_counts_combined_cpu = token_counts_combined.to(
|
||||
torch.device("cpu"), non_blocking=True).numpy()
|
||||
all_tokens = gather_sizes.sum()
|
||||
|
||||
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
|
||||
quantized_tokens.shape[1])
|
||||
dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
|
||||
gather_size_list = token_counts_combined_cpu[1]
|
||||
scatter_size_list = token_counts_combined_cpu[0]
|
||||
|
||||
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
||||
dist.all_to_all_single(gathered_tokens, quantized_tokens,
|
||||
scatter_size_list, gather_size_list)
|
||||
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
|
||||
gather_size_list)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
|
||||
gathered_tokens,
|
||||
gather_sizes.view(ep_group.world_size, -1),
|
||||
per_token_scales=dynamic_scale)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
dynamic_scale = None
|
||||
|
||||
# `hidden_states` will be disposed in the `apply_mlp` function
|
||||
hidden_states = apply_mlp(
|
||||
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_tokens, #16
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
if expert_map is not None:
|
||||
resorted_idx = torch.argsort(sorted_idx)
|
||||
hidden_states = hidden_states[resorted_idx]
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
gather_size_list,
|
||||
scatter_size_list)
|
||||
reordered_outputs = torch.index_select(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
# Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
|
||||
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
|
||||
|
||||
hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
|
||||
dist.all_to_all_single(hidden_states, reordered_outputs,
|
||||
gather_size_list, scatter_size_list)
|
||||
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
export_for_source_row=None,
|
||||
drop_pad_mode=2)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
|
||||
Reference in New Issue
Block a user