### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into a single fused
operation, `GroupedMatmulSwigluQuant`.
1. Extract the functions shared by `w4a8_dynamic.py` and `w8a8_dynamic.py`.
2. When the configuration supports it, use the fused operation
`npu_grouped_matmul_swiglu_quant`.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on the W8A8-quantized Qwen3-235B-A22B model with `bs=16`:
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`: TPOT improved by 21.54%, output
token throughput increased by 27.35%.
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
2. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`: TPOT improved by 17.38%, output
token throughput increased by 6.86%.
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
50 lines
2.1 KiB
Python
50 lines
2.1 KiB
Python
from unittest.mock import Mock, patch
|
|
|
|
import torch
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.quantization.w8a8_dynamic import \
|
|
AscendW8A8DynamicFusedMoEMethod
|
|
|
|
|
|
class TestAscendW8A8FusedMoEMethod(TestBase):
    """Unit tests for the parameter factories of
    ``AscendW8A8DynamicFusedMoEMethod``."""

    num_experts = 8
    hidden_size = 128
    intermediate_size = 128

    # ``@patch`` decorators are applied bottom-up, so the patch closest to
    # the ``def`` (``get_ep_group``) supplies the first mock argument.
    @patch("torch.distributed.get_rank")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
              mock_get_mc2_group, mock_get_rank):
        """Construct the quant method with its group/config/rank lookups
        replaced by mocks, so no distributed environment is required."""
        mock_get_ep_group.return_value = Mock()

        # Mocked Ascend config reports torchair graph mode as disabled.
        ascend_cfg = Mock()
        ascend_cfg.torchair_graph_config = Mock(enabled=False)
        mock_get_ascend_config.return_value = ascend_cfg

        mock_get_mc2_group.return_value = Mock(device_group=0)
        mock_get_rank.return_value = Mock()

        # The patches are only active for the duration of setUp, so the
        # method object must be built here.
        self.quant_method = AscendW8A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        """``w13_weight`` is int8 with shape
        (num_experts, 2 * intermediate_size, hidden_size)."""
        params = self.quant_method.get_weight(
            self.num_experts, self.intermediate_size, self.hidden_size,
            torch.bfloat16)
        w13 = params["w13_weight"]
        expected_shape = (self.num_experts, 2 * self.intermediate_size,
                          self.hidden_size)

        self.assertEqual(w13.dtype, torch.int8)
        self.assertEqual(w13.shape, expected_shape)

    def test_get_dynamic_quant_param(self):
        """``w13_weight_scale`` keeps the requested dtype (bfloat16 here)
        and has shape (num_experts, 2 * intermediate_size, 1)."""
        params = self.quant_method.get_dynamic_quant_param(
            self.num_experts, self.intermediate_size, self.hidden_size,
            torch.bfloat16)
        scale = params["w13_weight_scale"]
        expected_shape = (self.num_experts, 2 * self.intermediate_size, 1)

        self.assertEqual(scale.dtype, torch.bfloat16)
        self.assertEqual(scale.shape, expected_shape)
|