Files
vllm-ascend/tests/ut/_310p/quantization/test_w8a8_static_310.py
pu-zhe 02886e2641 [Feat] 310p support MoE W8A8 quantization (#6641)
### What this PR does / why we need it?
This PR introduces support for W8A8 dynamic quantization for
Mixture-of-Experts (MoE) models on Ascend 310P devices. This is achieved
by:
- Implementing a new quantization scheme
`AscendW8A8DynamicFusedMoEMethod310`.
- Adding a unified MLP implementation (`unified_apply_mlp`) for 310P
that handles both quantized and unquantized paths.
- Refactoring the MoE and quantization configuration logic to correctly
route to the new 310P-specific implementations (see the sketch after
this list).
- Adding new e2e and unit tests to verify the functionality of MoE W8A8
quantization.
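
As a rough illustration of the routing described above, the configuration logic dispatches on both the quantization config and the SoC generation. This is a minimal sketch only; `is_310p`, `get_moe_method`, and the placeholder classes below are hypothetical stand-ins, not the actual vllm-ascend API:

```python
# Hypothetical sketch of the dispatch idea; not the actual vllm-ascend code.
from typing import Optional


class UnquantizedFusedMoEMethod:  # placeholder: unquantized path
    ...


class AscendW8A8DynamicFusedMoEMethod:  # placeholder: pre-existing path
    ...


class AscendW8A8DynamicFusedMoEMethod310:  # placeholder: new 310P path
    ...


def is_310p() -> bool:
    # The real code would query the Ascend SoC version here.
    return True


def get_moe_method(quant_method: Optional[str]):
    """Route a fused-MoE layer to the matching implementation."""
    if quant_method is None:
        return UnquantizedFusedMoEMethod()
    if quant_method == "w8a8_dynamic" and is_310p():
        return AscendW8A8DynamicFusedMoEMethod310()
    return AscendW8A8DynamicFusedMoEMethod()
```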

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Added a new e2e test `test_qwen3_moe_tp2_w8a8` to test MoE W8A8
quantization in a multi-card setup.
- Added several new unit tests for the 310P-specific MoE components,
including `experts_selector`, `fused_moe`, `moe_comm_method`, `moe_mlp`,
and the new `w8a8_dynamic` quantization method.
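
For orientation, a multi-card e2e quantization test of this shape typically loads a quantized checkpoint with tensor parallelism and checks that generation succeeds. The following is a hypothetical skeleton only; the checkpoint name is a placeholder and the exact fixture and arguments are assumptions, not the actual test code:

```python
# Hypothetical skeleton; not the actual test_qwen3_moe_tp2_w8a8 code.
import pytest
from vllm import LLM, SamplingParams


@pytest.mark.parametrize("model", ["<some-w8a8-quantized-qwen3-moe-ckpt>"])
def test_qwen3_moe_tp2_w8a8(model):
    llm = LLM(model=model,
              tensor_parallel_size=2,   # multi-card (TP=2) setup
              quantization="ascend")    # Ascend W8A8 quantized weights
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    assert outputs and outputs[0].outputs[0].text
```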

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
2026-02-10 17:17:44 +08:00

106 lines
4.5 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend._310p.quantization.methods.w8a8_static import (
    AscendW8A8LinearMethod310)


class TestAscendW8A8LinearMethod310(TestBase):

    def setUp(self):
        self.method = AscendW8A8LinearMethod310()

    def test_get_weight_310(self):
        # get_weight(input_size, output_size) should allocate an int8
        # weight laid out as (output_size, input_size).
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight["weight"].dtype, torch.int8)
        self.assertEqual(weight["weight"].shape, (20, 10))

    def test_get_pertensor_param_310(self):
        # Static W8A8 quantizes activations with a single per-tensor
        # scale/offset pair.
        params = self.method.get_pertensor_param(torch.float16)
        self.assertEqual(params["input_scale"].dtype, torch.float16)
        self.assertEqual(params["input_offset"].dtype, torch.int8)
        self.assertEqual(params["input_scale"].shape, (1,))
        self.assertEqual(params["input_offset"].shape, (1,))

    def test_get_perchannel_param_310(self):
        # Per-channel parameters are sized by the number of output channels.
        params = self.method.get_perchannel_param(10, torch.float16)
        self.assertEqual(params["quant_bias"].dtype, torch.int32)
        self.assertEqual(params["deq_scale"].dtype, torch.int64)
        self.assertEqual(params["weight_scale"].dtype, torch.float16)
        self.assertEqual(params["weight_offset"].dtype, torch.float16)
        self.assertEqual(params["quant_bias"].shape, (10,))
        self.assertEqual(params["deq_scale"].shape, (10,))
        self.assertEqual(params["weight_scale"].shape, (10, 1))
        self.assertEqual(params["weight_offset"].shape, (10, 1))
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8_310(self, mock_npu_quant_matmul, mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
layer.weight = torch.randn(128, 256)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256,))
layer.params_dtype = torch.float16
x = torch.randn(32, 128)
expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
mock_quantize.return_value = expect_x_output
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, tp_rank=0)
mock_quantize.assert_called_with(
x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
)
mock_npu_quant_matmul.assert_called_with(
expect_x_output, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
)
# The bias is added by the linear layer's forward pass, not the quant method.
self.assertTrue(torch.equal(output, expected_y_output))
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_is_int8_310(self, mock_npu_quant_matmul, mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
layer.weight = torch.randn(128, 256)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256,))
layer.params_dtype = torch.float16
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, tp_rank=0)
mock_quantize.assert_not_called()
mock_npu_quant_matmul.assert_called_with(
x, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
)
# The bias is added by the linear layer's forward pass, not the quant method.
self.assertTrue(torch.equal(output, expected_y_output))