[Refactor]refactor 310p ops and add ut (#6591)

### What this PR does / why we need it?
This pull request focuses on a significant refactoring effort within the
vllm-ascend project, specifically targeting operations optimized for the
Ascend 310P hardware. The changes aim to streamline the implementation
of core components like quantization and multi-head attention, making
the codebase more maintainable and robust. Concurrently, new unit tests
have been introduced to ensure the correctness and reliability of these
refactored modules.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
E2E test with qwen3-32b w8a8

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
pu-zhe
2026-02-07 09:25:17 +08:00
committed by GitHub
parent 6c49f95da2
commit 23524f2ca4
6 changed files with 173 additions and 28 deletions

View File

@@ -0,0 +1,67 @@
from unittest.mock import MagicMock, patch
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
class TestAscendModelSlimConfig310(TestBase):
def setUp(self):
self.sample_config = {
"weight": "INT8",
"layer1.weight": "INT8",
"layer2.weight": "FLOAT",
"fused_layer.weight": "FLOAT",
"fused_layer.shard1.weight": "FLOAT",
"fused_layer.shard2.weight": "FLOAT",
"shard1.weight": "FLOAT",
"shard2.weight": "FLOAT",
}
self.ascend_config = AscendModelSlimConfig310(self.sample_config)
self.ascend_config.packed_modules_mapping = None
def test_get_quant_method_for_linear_310(self):
mock_config = MagicMock()
mock_config.model_config.hf_config.model_type = None
linear_layer = MagicMock(spec=LinearBase)
# Test skipped layer
with (
patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True)
):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
# Test quantized layer
mock_scheme = MagicMock()
with (
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=False),
patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
patch("vllm_ascend._310p.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme),
patch(
"vllm_ascend._310p.quantization.modelslim_config.AscendLinearMethod", return_value=MagicMock()
) as mock_ascend_linear,
):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIs(method, mock_ascend_linear.return_value)
mock_ascend_linear.assert_called_once_with(mock_scheme)
def test_get_quant_method_for_fused_moe_310(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
mock_config = MagicMock()
mock_config.model_config.hf_config.model_type = None
mock_scheme = MagicMock()
with (
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=False),
patch("vllm_ascend._310p.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config),
patch("vllm_ascend._310p.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme),
patch("vllm_ascend._310p.quantization.modelslim_config.AscendLinearMethod", return_value=MagicMock()),
self.assertRaises(NotImplementedError),
):
self.ascend_config.get_quant_method(fused_moe_layer, "moe_layer")

View File

@@ -0,0 +1,90 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend._310p.quantization.methods.w8a8_static import AscendW8A8LinearMethod310
class TestAscendW8A8LinearMethod310(TestBase):
def setUp(self):
self.method = AscendW8A8LinearMethod310()
def test_get_weight_310(self):
weight = self.method.get_weight(10, 20)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (20, 10))
def test_get_pertensor_param_310(self):
params = self.method.get_pertensor_param(torch.bfloat16)
self.assertEqual(params["input_scale"].dtype, torch.bfloat16)
self.assertEqual(params["input_offset"].dtype, torch.int8)
self.assertEqual(params["input_scale"].shape, (1,))
self.assertEqual(params["input_offset"].shape, (1,))
def test_get_perchannel_param_310(self):
params = self.method.get_perchannel_param(10, torch.bfloat16)
self.assertEqual(params["quant_bias"].dtype, torch.int32)
self.assertEqual(params["deq_scale"].dtype, torch.float32)
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
self.assertEqual(params["quant_bias"].shape, (10,))
self.assertEqual(params["deq_scale"].shape, (10,))
self.assertEqual(params["weight_scale"].shape, (10, 1))
self.assertEqual(params["weight_offset"].shape, (10, 1))
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8_310(self, mock_npu_quant_matmul, mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
layer.weight = torch.randn(128, 256)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256,))
layer.params_dtype = torch.float16
x = torch.randn(32, 128)
expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
mock_quantize.return_value = expect_x_output
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, tp_rank=0)
mock_quantize.assert_called_with(
x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
)
mock_npu_quant_matmul.assert_called_with(
expect_x_output, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
)
# The bias is added by the linear layer's forward pass, not the quant method.
self.assertTrue(torch.equal(output, expected_y_output))
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_is_int8_310(self, mock_npu_quant_matmul, mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
layer.weight = torch.randn(128, 256)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256,))
layer.params_dtype = torch.float16
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, tp_rank=0)
mock_quantize.assert_not_called()
mock_npu_quant_matmul.assert_called_with(
x, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
)
# The bias is added by the linear layer's forward pass, not the quant method.
self.assertTrue(torch.equal(output, expected_y_output))