This PR significantly improves the performance of quantized Mixture of
Experts (MoE) layers by reordering the quantization and communication
operations.
In the previous implementation, the `all2all` operation was performed on
the unquantized `hidden_states` (in FP16/BF16) *before* quantization,
resulting in substantial communication overhead. By performing
quantization on each EP rank **first** and then sending the INT8 data,
which is half the width of FP16/BF16, we reduce the communication volume
by nearly 50% (the small per-token scales that must also be exchanged
account for the remainder).
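A minimal sketch of the reordering, assuming per-token dynamic quantization via `torch_npu.npu_dynamic_quant`; the helper name `dispatch_quantized` and the separate exchange of scales are illustrative, not the actual `fused_experts_with_all2all` code:

```python
import torch
import torch.distributed as dist
import torch_npu


def dispatch_quantized(hidden_states: torch.Tensor):
    # New order: quantize locally on each EP rank first (INT8 + per-token scale)...
    quantized, scale = torch_npu.npu_dynamic_quant(hidden_states)

    # ...then exchange the INT8 tensor (1 byte/element) instead of the BF16
    # activations (2 bytes/element), roughly halving the traffic on the wire.
    recv = torch.empty_like(quantized)
    dist.all_to_all_single(recv, quantized)

    # The per-token scales travel too, but they are tiny: one value per
    # token versus hidden_size values per token for the activations.
    recv_scale = torch.empty_like(scale)
    dist.all_to_all_single(recv_scale, scale)
    return recv, recv_scale
```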
Additionally, this PR includes a minor optimization to cast `int` inputs
to `float` for the `argsort` operation, forcing it to run on a faster
NPU core instead of the AICPU.
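For example (values illustrative; the ids are small integers, so the `float` cast is exact and the ordering is unchanged):

```python
import torch

# Integer argsort may fall back to the slow AICPU on Ascend; sorting the
# float view of the same values keeps the kernel on the faster NPU core.
topk_ids = torch.randint(0, 256, (1024, ), dtype=torch.int32)
sorted_idx = torch.argsort(topk_ids.float())  # was: torch.argsort(topk_ids)
```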
Together, these changes deliver a clear performance gain in quantized
MoE scenarios.
- vLLM version: v0.10.0
- vLLM main: 7175817637
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
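The PR also adds a unit test that mocks the NPU kernels and the `all_to_all_single` collective, so `fused_experts_with_all2all` can be exercised end to end without NPU hardware: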
```python
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all


class TestAscendW8A8FusedMoEMethod(TestBase):

    def setUp(self):
        self.hidden_size = 128
        self.num_tokens = 128
        self.placeholder = torch.randn(self.num_tokens,
                                       self.hidden_size,
                                       dtype=torch.bfloat16)

    # Mock the NPU kernels and the collective so the test runs without
    # NPU hardware or an initialized process group.
    @patch("torch.distributed.all_to_all_single")
    @patch("torch_npu.npu_moe_re_routing")
    @patch("torch_npu.npu_grouped_matmul")
    @patch("torch_npu.npu_swiglu")
    @patch("torch_npu.npu_dynamic_quant")
    @patch("torch_npu.npu_moe_finalize_routing")
    @patch("torch_npu.npu_moe_init_routing")
    def test_fused_experts_with_all2all(self, mock_moe_init_routing,
                                        mock_moe_finalize_routing,
                                        mock_dynamic_quant, mock_swiglu,
                                        mock_grouped_matmul,
                                        mock_moe_re_routing,
                                        mock_all_to_all_single):
        expert_map = MagicMock()
        ep_group = MagicMock()
        placeholder_int8 = torch.randint(0,
                                         100,
                                         (self.num_tokens, self.hidden_size),
                                         dtype=torch.int8)
        placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
        # Loop the collective back: each rank "receives" exactly what it sent.
        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)
        mock_moe_init_routing.return_value = (
            placeholder_int8,
            placeholder_ones,
            placeholder_ones,
        )
        mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
                                            torch.randint(0,
                                                          100,
                                                          (self.num_tokens, ),
                                                          dtype=torch.int32),
                                            self.placeholder)
        mock_grouped_matmul.return_value = self.placeholder
        mock_swiglu.return_value = self.placeholder
        mock_dynamic_quant.return_value = (
            placeholder_int8,
            torch.randn(self.num_tokens),
        )
        mock_moe_finalize_routing.return_value = self.placeholder

        result = fused_experts_with_all2all(
            hidden_states=self.placeholder,
            w1=self.placeholder,
            w1_scale=self.placeholder,
            w2=self.placeholder,
            w2_scale=self.placeholder,
            topk_weights=self.placeholder,
            topk_ids=self.placeholder,
            top_k=8,
            expert_map=expert_map,
            ep_group=ep_group,
            log2phy=None,
            global_redundant_expert_num=256,
        )
        self.assertIsNotNone(result)
        self.assertEqual(result.dtype, torch.bfloat16)
        self.assertEqual(result.shape, (128, 128))
```