[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in W8A8_DYNAMIC quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into a single fused operator, `GroupedMatmulSwigluQuant`:
1. Extract the functions shared by `w4a8_dynamic.py` and `w8a8_dynamic.py` into common helpers.
2. Where the inputs are supported, call the fused operator
`npu_grouped_matmul_swiglu_quant` instead of the three separate kernels, as sketched below.
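A minimal before/after sketch of the change. This is illustrative only: the keyword names for the fused kernel are assumptions inferred from the unit tests in this commit, not a verified `torch_npu` signature.

```python
import torch_npu


def gate_up_before(x, w1, group_list):
    # Unfused path: three kernels per MoE gate/up stage.
    gate_up = torch_npu.npu_grouped_matmul([x], [w1],
                                           group_list=group_list,
                                           split_item=2, group_type=0)[0]
    activated = torch_npu.npu_swiglu(gate_up)
    return torch_npu.npu_dynamic_quant(activated)  # (int8 out, per-token scale)


def gate_up_after(x, w1, w1_scale, x_scale, group_list):
    # Fused path: one kernel covering grouped matmul + SwiGLU + dynamic quant.
    # Keyword names below are assumptions, not the documented signature.
    out, out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
        x=x, weight=w1, group_list=group_list,
        weight_scale=w1_scale, x_scale=x_scale)
    return out, out_scale
```

The gains reported below presumably come from fewer kernel launches and from skipping the intermediate bf16 activation round trip between the three kernels.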
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on the W8A8-quantized Qwen3-235B-A22B model with `bs=16`:
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`: TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
2. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`: TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main: 6997a25ac6
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
```diff
@@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import (FusedMoEState,
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp
 from vllm_ascend.utils import AscendSocVersion, adapt_patch
 
 adapt_patch(True)
```
```diff
@@ -524,6 +524,43 @@ class TestExpertsSelector:
         assert topk_ids.shape == (8, 2)
 
 
+class TestCumsumGroupList(TestBase):
+
+    def setUp(self):
+        self.active_num = 8
+        self.expert_num = 128
+        self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
+        self.experts[:self.active_num] = 1
+        self.experts = self.experts[torch.randperm(self.expert_num)]
+        self.group_list = self.experts.cumsum(dim=0)
+
+    def test_cumsum_group_list_with_type_0(self):
+        group_list = self.experts.cumsum(dim=0)
+        group_list_type = 0
+        result = cumsum_group_list(group_list, group_list_type)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+    def test_cumsum_group_list_with_type_1(self):
+        group_list = self.experts
+        group_list_type = 1
+        result = cumsum_group_list(group_list, group_list_type)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+    def test_cumsum_group_list_with_type_2(self):
+        tokens = torch.arange(self.expert_num, dtype=torch.int64)
+        group_list = torch.cat([
+            tokens.reshape(self.expert_num, 1),
+            self.experts.reshape(self.expert_num, 1)
+        ],
+                               dim=1)
+        group_list_type = 2
+        result = cumsum_group_list(group_list,
+                                   group_list_type,
+                                   active_num=self.active_num,
+                                   expert_num=self.expert_num)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+
 class TestUnifiedApplyMLP(TestBase):
 
     @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
```
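These tests also document the three `group_list` encodings that `cumsum_group_list` normalizes: type 0 is an already-cumulative token count, type 1 is per-expert token counts, and type 2 is `(expert_id, count)` pairs. A hypothetical reference implementation consistent with the assertions above (the real helper in `vllm_ascend.ops.layers.moe_mlp` may differ, e.g. in how it uses `active_num`):

```python
import torch


def cumsum_group_list_sketch(group_list: torch.Tensor,
                             group_list_type: int,
                             active_num: int = 0,
                             expert_num: int = 0) -> torch.Tensor:
    if group_list_type == 0:
        # Already a cumulative sum over experts.
        return group_list
    if group_list_type == 1:
        # Per-expert token counts -> cumulative sum.
        return group_list.cumsum(dim=0)
    # Type 2: (expert_id, count) pairs -> dense per-expert counts -> cumsum.
    # active_num is accepted for signature parity; this sketch does not use it.
    counts = torch.zeros(expert_num, dtype=group_list.dtype)
    counts[group_list[:, 0]] = group_list[:, 1]
    return counts.cumsum(dim=0)
```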
```diff
@@ -739,3 +776,68 @@ class TestUnifiedApplyMLP(TestBase):
 
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
+
+    @patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context")
+    @patch("torch_npu.npu_grouped_matmul")
+    @patch("torch_npu.npu_swiglu")
+    @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
+    @patch("torch_npu.npu_dynamic_quant")
+    def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
+            self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
+            mock_npu_swiglu, mock_npu_grouped_matmul,
+            mock_get_forward_context):
+
+        mock_forward_context = MagicMock()
+        mock_forward_context.with_quant = True
+        mock_forward_context.fused_moe_state = "NOT_MC2"
+        mock_get_forward_context.return_value = mock_forward_context
+
+        mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
+            -128, 127, (10, 40),
+            dtype=torch.int8), torch.rand(
+                10, 1,
+                dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
+        mock_npu_grouped_matmul.side_effect = [[
+            torch.randn(10, 20, dtype=torch.bfloat16)
+        ]]
+        mock_npu_swiglu.return_value = torch.randn(10,
+                                                   40,
+                                                   dtype=torch.bfloat16)
+        mock_npu_dynamic_quant.return_value = (torch.randint(-128,
+                                                             127, (10, 40),
+                                                             dtype=torch.int8),
+                                               torch.rand(10,
+                                                          1,
+                                                          dtype=torch.float32))
+
+        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
+        w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
+        w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
+        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
+        w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
+        w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
+        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
+        provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
+
+        result = unified_apply_mlp(hidden_states=hidden_states,
+                                   w1=w1,
+                                   w1_scale=w1_scale,
+                                   w2=w2,
+                                   w2_scale=w2_scale,
+                                   group_list=group_list,
+                                   dynamic_scale=provided_dynamic_scale,
+                                   group_list_type=1,
+                                   w1_scale_bias=w1_scale_bias,
+                                   w2_scale_bias=w2_scale_bias,
+                                   topk_scales=None,
+                                   with_quant=True,
+                                   fusion=True)
+
+        mock_get_forward_context.assert_called()
+        mock_npu_grouped_matmul.assert_called_once()
+        mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
+
+        self.assertTrue(mock_forward_context.with_quant)
+        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.dtype, torch.bfloat16)
```
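This test pins down the fused control flow: with `fusion=True` and quantization enabled, the gate/up projection, SwiGLU, and requantization collapse into one `npu_grouped_matmul_swiglu_quant` call, and only the down projection still goes through `npu_grouped_matmul`, which is why each mock is asserted exactly once. A hedged sketch of that branch; the keyword names and per-token-scale handling are assumptions, not the exact code in `moe_mlp.py`:

```python
import torch
import torch_npu


def fused_quant_mlp_sketch(hidden_states, dynamic_scale, w1, w1_scale,
                           w2, w2_scale, group_list):
    # Stage 1 (fused): grouped matmul (gate/up) + SwiGLU + dynamic quant.
    quant_out, token_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
        x=hidden_states,        # int8 routed activations
        weight=w1,              # int8 gate/up expert weights
        group_list=group_list,  # cumulative token count per expert
        weight_scale=w1_scale,
        x_scale=dynamic_scale)
    # Stage 2 (unfused): down projection via the regular grouped matmul,
    # dequantizing with the weight scale and the stage-1 per-token scale.
    return torch_npu.npu_grouped_matmul([quant_out], [w2],
                                        scale=[w2_scale],
                                        per_token_scale=[token_scale],
                                        group_list=group_list,
                                        split_item=2, group_type=0,
                                        output_dtype=torch.bfloat16)[0]
```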
**tests/ut/quantization/test_w8a8_dynamic.py** (new file, 49 lines)

```diff
@@ -0,0 +1,49 @@
+from unittest.mock import Mock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
+
+
+class TestAscendW8A8FusedMoEMethod(TestBase):
+    num_experts = 8
+    hidden_size = 128
+    intermediate_size = 128
+
+    @patch("torch.distributed.get_rank")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
+    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
+              mock_get_mc2_group, mock_get_rank):
+        mock_ep_group = Mock()
+        mock_get_ep_group.return_value = mock_ep_group
+        mock_ascend_config = Mock()
+        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
+        mock_get_ascend_config.return_value = mock_ascend_config
+        mock_mc2_group = Mock(device_group=0)
+        mock_get_mc2_group.return_value = mock_mc2_group
+        mock_rank = Mock()
+        mock_get_rank.return_value = mock_rank
+
+        self.quant_method = AscendW8A8DynamicFusedMoEMethod()
+
+    def test_get_weight(self):
+        param_dict = self.quant_method.get_weight(self.num_experts,
+                                                  self.intermediate_size,
+                                                  self.hidden_size,
+                                                  torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(
+            param_dict["w13_weight"].shape,
+            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))
+
+    def test_get_dynamic_quant_param(self):
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.num_experts, self.intermediate_size, self.hidden_size,
+            torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].shape,
+                         (self.num_experts, 2 * self.intermediate_size, 1))
```