xc-llm-ascend/tests/ut/_310p/quantization/test_modelslim_config_310.py
pu-zhe 02886e2641 [Feat] 310p support MoE W8A8 quantization (#6641)
### What this PR does / why we need it?
This PR introduces support for W8A8 dynamic quantization for
Mixture-of-Experts (MoE) models on Ascend 310P devices. This is achieved
by:
- Implementing a new quantization scheme
`AscendW8A8DynamicFusedMoEMethod310`.
- Adding a unified MLP implementation (`unified_apply_mlp`) for 310P
that handles both quantized and unquantized paths.
- Refactoring the MoE and quantization configuration logic to correctly
route to the new 310P-specific implementations (see the sketch after this list).
- Adding new e2e and unit tests to verify the functionality of MoE W8A8
quantization.
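
The unit test in this file exercises the routing bullet above, i.e. `AscendModelSlimConfig310.get_quant_method`. As a rough mental model only: the helper signatures and constructor arguments below are inferred from the test's patch targets and assertions, not copied from the implementation, so treat this as an illustrative sketch rather than the PR's code.

```python
# Illustrative sketch of the 310P quant-method routing asserted by the unit
# test below. Helper signatures and constructor arguments are assumptions
# inferred from the test's patch targets; the real code may differ.
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase

from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
from vllm_ascend._310p.quantization.modelslim_config import (
    AscendFusedMoEMethod, AscendLinearMethod, create_scheme_for_layer)
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod


def route_quant_method_sketch(quant_config, layer, prefix):
    # Argument order of is_layer_skipped_ascend / create_scheme_for_layer is assumed.
    skipped = quant_config.is_layer_skipped_ascend(
        prefix, quant_config.packed_modules_mapping)

    if isinstance(layer, LinearBase):
        if skipped:
            return AscendUnquantizedLinearMethod()
        return AscendLinearMethod(create_scheme_for_layer(quant_config, layer, prefix))

    if isinstance(layer, FusedMoE):
        if skipped:
            # Constructor arguments of the 310P unquantized MoE method are assumed.
            return AscendUnquantizedFusedMoEMethod310(layer.moe)
        return AscendFusedMoEMethod(
            create_scheme_for_layer(quant_config, layer, prefix), layer.moe_config)

    return None
```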

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Added a new e2e test `test_qwen3_moe_tp2_w8a8` to cover MoE W8A8
quantization in a multi-card setup (a rough usage sketch follows this list).
- Added several new unit tests for the 310P-specific MoE components,
including `experts_selector`, `fused_moe`, `moe_comm_method`, `moe_mlp`,
and the new `w8a8_dynamic` quantization method.
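
For orientation, here is a hypothetical sketch of the kind of multi-card W8A8 smoke test described above. The model path, prompt, and the `quantization="ascend"` flag are assumptions for illustration; this is not the actual `test_qwen3_moe_tp2_w8a8` code from the PR.

```python
# Hypothetical smoke-test sketch only, not the e2e test added by this PR.
from vllm import LLM, SamplingParams


def smoke_test_moe_w8a8_tp2(model_path: str) -> None:
    llm = LLM(
        model=model_path,            # a ModelSlim W8A8-quantized Qwen3-MoE checkpoint (assumed)
        quantization="ascend",       # assumed flag for Ascend ModelSlim weights
        tensor_parallel_size=2,      # the two-card setup described above
        max_model_len=1024,
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    # The smoke check only asserts that generation produced non-empty text.
    assert outputs[0].outputs[0].text.strip()
```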

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
2026-02-10 17:17:44 +08:00

105 lines
5.3 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig, FusedMoEParallelConfig
from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod


class TestAscendModelSlimConfig310(TestBase):

    def setUp(self):
        self.sample_config = {
            "weight": "INT8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.ascend_config = AscendModelSlimConfig310(self.sample_config)
        self.ascend_config.packed_modules_mapping = None

    def test_get_quant_method_for_linear_310(self):
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None
        linear_layer = MagicMock(spec=LinearBase)

        # Test skipped layer
        with (
            patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
        ):
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIsInstance(method, AscendUnquantizedLinearMethod)

        # Test quantized layer
        mock_scheme = MagicMock()
        with (
            patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=False),
            patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend._310p.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme),
            patch(
                "vllm_ascend._310p.quantization.modelslim_config.AscendLinearMethod", return_value=MagicMock()
            ) as mock_ascend_linear,
        ):
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIs(method, mock_ascend_linear.return_value)
            mock_ascend_linear.assert_called_once_with(mock_scheme)

    def test_get_quant_method_for_fused_moe_310(self):
        fused_moe_layer = MagicMock(spec=FusedMoE)
        fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
        fused_moe_layer.moe_config.moe_parallel_config = MagicMock(spec=FusedMoEParallelConfig)
        fused_moe_layer.moe_config.moe_parallel_config.use_ep = True
        fused_moe_layer.moe_config.moe_parallel_config.dp_size = 1

        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None
        mock_config.compilation_config.custom_ops = ["all"]
        mock_scheme = MagicMock()

        # Test skipped layer
        with (
            patch("vllm.config.vllm.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
        ):
            method = self.ascend_config.get_quant_method(fused_moe_layer, ".moe")
            self.assertIsInstance(method, AscendUnquantizedFusedMoEMethod310)

        # Test quantized layer
        mock_scheme = MagicMock()
        with (
            patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=False),
            patch("vllm.config.vllm.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
            patch("vllm_ascend._310p.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme),
            patch(
                "vllm_ascend._310p.quantization.modelslim_config.AscendFusedMoEMethod", return_value=MagicMock()
            ) as fused_moe_method,
        ):
            method = self.ascend_config.get_quant_method(fused_moe_layer, ".moe")
            self.assertIs(method, fused_moe_method.return_value)
            fused_moe_method.assert_called_once_with(mock_scheme, fused_moe_layer.moe_config)