From 7b2ecc1e9a64aeda78e2137aa06abdbf2890c000 Mon Sep 17 00:00:00 2001
From: anon189Ty
Date: Thu, 11 Sep 2025 11:40:00 +0800
Subject: [PATCH] [Feat] Unquantized linear nz support (#2619)

### What this PR does / why we need it?
Currently, when execution reaches a Linear layer of the model in vLLM-Ascend, the weight format is ND in both the unquantized case and the skipped-ascend case, which is slower than FRACTAL_NZ.

This PR supplements the execution logic for the Linear layer: when `VLLM_ASCEND_ENABLE_MLP_OPTIMIZE=1` and the CANN version is 8.3, the weights of the Linear layer are converted to FRACTAL_NZ, in both the unquantized case and the skipped-ascend case.

- vLLM version: main
- vLLM main: https://github.com/vllm-project/vllm/commit/267c80d31f6b77092a5d5903da64556ac15c4d4d
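For reviewers, a minimal sketch of the weight-format conversion this PR performs on the CANN 8.3 path (assumes an Ascend NPU at `npu:0` and a vllm-ascend install providing `torch_npu` and `ACL_FORMAT_FRACTAL_NZ`; shapes are illustrative):

```python
import torch
import torch_npu  # Ascend backend; provides npu_format_cast

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

# vLLM stores Linear weights as (out_features, in_features).
weight = torch.randn(256, 128).to("npu:0")

if torch.version.cann.startswith("8.3"):
    # Same steps as AscendUnquantizedLinearMethod.process_weights_after_loading:
    # transpose to (in_features, out_features), then cast ND -> FRACTAL_NZ.
    weight = weight.transpose(0, 1).contiguous()
    weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)

# apply() can then use a plain matmul against the transposed NZ weight.
x = torch.randn(32, 128).to("npu:0")
y = torch.matmul(x, weight)  # -> (32, 256)
```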
Signed-off-by: anon189Ty
---
 tests/ut/ops/test_linear.py                | 79 +++++++++++++++++++++-
 tests/ut/quantization/test_quant_config.py |  6 +-
 vllm_ascend/ops/linear.py                  | 30 +++++++-
 vllm_ascend/quantization/quant_config.py   |  6 +-
 4 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index a0d7f06..6d678bd 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -5,11 +5,13 @@ from unittest.mock import MagicMock, patch
 
 import torch
 
+from tests.ut.base import TestBase
 from vllm_ascend import ascend_config
 from vllm_ascend.distributed import parallel_state
 from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                     AscendMergedColumnParallelLinear,
-                                    AscendRowParallelLinear)
+                                    AscendRowParallelLinear,
+                                    AscendUnquantizedLinearMethod)
 
 
 class BaseLinearTest(unittest.TestCase):
@@ -46,6 +48,81 @@ class BaseLinearTest(unittest.TestCase):
         p.stop()
 
 
+class TestAscendUnquantizedLinearMethod(TestBase):
+
+    def setUp(self):
+        self.method = AscendUnquantizedLinearMethod()
+
+    @mock.patch("torch_npu.npu_format_cast")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_is_cann_8_3(self, mock_version,
+                                                       mock_format_cast):
+        layer = mock.MagicMock()
+
+        mock_version.cann = "8.3.RC1"
+        self.method.process_weights_after_loading(layer)
+        mock_format_cast.assert_called_once()
+
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_not_cann_8_3(self, mock_version):
+        layer = mock.MagicMock()
+
+        mock_version.cann = "8.2.RC1"
+        # Should not raise exception
+        self.method.process_weights_after_loading(layer)
+
+    @mock.patch("torch.matmul")
+    @mock.patch("torch.version")
+    def test_apply_with_bias_is_cann_8_3(self, mock_version, mock_npu_matmul):
+        layer = mock.MagicMock()
+        layer.weight = torch.randn(128, 256)
+
+        x = torch.randn(32, 128)
+        bias = torch.randn(256)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_matmul.return_value = expected_y_output
+
+        mock_version.cann = "8.3.RC1"
+        output = self.method.apply(layer, x, bias)
+
+        expected_y_output += bias
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @mock.patch("torch.matmul")
+    @mock.patch("torch.version")
+    def test_apply_without_bias_is_cann_8_3(self, mock_version,
+                                            mock_npu_matmul):
+        layer = mock.MagicMock()
+        layer.weight = torch.randn(128, 256)
+
+        x = torch.randn(32, 128)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_matmul.return_value = expected_y_output
+
+        mock_version.cann = "8.3.RC1"
+        output = self.method.apply(layer, x)
+
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @mock.patch("torch.nn.functional.linear")
+    @mock.patch("torch.version")
+    def test_apply_not_cann_8_3(self, mock_version, mock_npu_linear):
+        layer = mock.MagicMock()
+        layer.weight = torch.randn(128, 256)
+
+        x = torch.randn(32, 128)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_linear.return_value = expected_y_output
+
+        mock_version.cann = "8.2.RC1"
+        output = self.method.apply(layer, x)
+
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+
 class TestAscendRowParallelLinear(BaseLinearTest):
 
     def test_mlp_optimize(self):
diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py
index fa5d13e..07f546b 100644
--- a/tests/ut/quantization/test_quant_config.py
+++ b/tests/ut/quantization/test_quant_config.py
@@ -4,10 +4,10 @@
 import torch
 from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
-from vllm.model_executor.layers.linear import (LinearBase,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
                                                    AscendQuantConfig)
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -79,7 +79,7 @@ class TestAscendQuantConfig(TestBase):
                               'is_layer_skipped_ascend',
                               return_value=True):
             method = self.ascend_config.get_quant_method(linear_layer,
                                                          ".attn")
-            self.assertIsInstance(method, UnquantizedLinearMethod)
+            self.assertIsInstance(method, AscendUnquantizedLinearMethod)
 
         # Test quantized layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend',
                           return_value=False), \
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 6bf9676..0c6f430 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -36,12 +36,36 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
-                               mlp_tp_enable, oproj_tp_enable)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dense_optim_enable,
+                               matmul_allreduce_enable, mlp_tp_enable,
+                               oproj_tp_enable)
 
 _HCOMM_INFO = None
 
 
+class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
+    """Linear method without quantization."""
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+        if torch.version.cann.startswith("8.3"):
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if torch.version.cann.startswith("8.3"):
+            if bias is None:
+                return torch.matmul(x, layer.weight)
+            else:
+                return torch.matmul(x, layer.weight) + bias
+        else:
+            return torch.nn.functional.linear(x, layer.weight, bias)
+
+
 class AscendColumnParallelLinear(ColumnParallelLinear):
     """Linear layer with column parallelism.
@@ -617,7 +641,7 @@ class AscendLinearBase(LinearBase):
         self.prefix = prefix
         if quant_config is None:
             self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+                QuantizeMethodBase] = AscendUnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 9cf84e8..95cc02c 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -23,8 +23,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import (
@@ -38,6 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
@@ -92,7 +92,7 @@ class AscendQuantConfig(QuantizationConfig):
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return UnquantizedLinearMethod()
+                return AscendUnquantizedLinearMethod()
             return AscendLinearMethod(self, prefix,
                                       self.packed_modules_mapping)
         elif isinstance(layer, Attention) and \
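As a quick, CPU-runnable sanity check of the version gate introduced above, here is a sketch that mocks `torch.version` the same way the new unit tests do (it assumes a vllm-ascend install so `vllm_ascend.ops.linear` imports cleanly; the stub layer and shapes are made up for illustration):

```python
from types import SimpleNamespace
from unittest import mock

import torch

from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod

method = AscendUnquantizedLinearMethod()
# apply() only touches layer.weight, so a namespace stands in for the module.
# On the CANN 8.3 path the weight has already been transposed to
# (in_features, out_features) by process_weights_after_loading.
layer = SimpleNamespace(weight=torch.randn(128, 256))
x = torch.randn(32, 128)

with mock.patch("torch.version") as fake_version:
    fake_version.cann = "8.3.RC1"
    y = method.apply(layer, x)  # takes the torch.matmul fast path
assert y.shape == (32, 256)
```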