From 756b8a1946aa9396d5bc7b9c67547fcb93fad630 Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Fri, 12 Sep 2025 20:51:12 +0800 Subject: [PATCH] Revert "[Feat] Unquantized linear nz support (#2619)" (#2896) ### What this PR does / why we need it? This reverts commit 7b2ecc1e9a64aeda78e2137aa06abdbf2890c000. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/64d90c3e4fe2a0e4395b8c94344dcdf78fa4cd22 Closes: https://github.com/vllm-project/vllm-ascend/issues/2890 Closes: https://github.com/vllm-project/vllm-ascend/issues/2887 Closes: https://github.com/vllm-project/vllm-ascend/issues/2885 Signed-off-by: Yikun Jiang --- tests/ut/ops/test_linear.py | 79 +--------------------- tests/ut/quantization/test_quant_config.py | 6 +- vllm_ascend/ops/linear.py | 30 +------- vllm_ascend/quantization/quant_config.py | 6 +- 4 files changed, 10 insertions(+), 111 deletions(-) diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index 6d678bd..a0d7f06 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -5,13 +5,11 @@ from unittest.mock import MagicMock, patch import torch -from tests.ut.base import TestBase from vllm_ascend import ascend_config from vllm_ascend.distributed import parallel_state from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, - AscendRowParallelLinear, - AscendUnquantizedLinearMethod) + AscendRowParallelLinear) class BaseLinearTest(unittest.TestCase): @@ -48,81 +46,6 @@ class BaseLinearTest(unittest.TestCase): p.stop() -class TestAscendUnquantizedLinearMethod(TestBase): - - def setUp(self): - self.method = AscendUnquantizedLinearMethod() - - @mock.patch("torch_npu.npu_format_cast") - @mock.patch("torch.version") - def test_process_weights_after_loading_is_cann_8_3(self, mock_version, - mock_format_cast): - layer = mock.MagicMock() - - mock_version.cann = "8.3.RC1" - self.method.process_weights_after_loading(layer) - mock_format_cast.assert_called_once() - - @mock.patch("torch.version") - def test_process_weights_after_loading_not_cann_8_3(self, mock_version): - layer = mock.MagicMock() - - mock_version.cann = "8.2.RC1" - # Should not raise exception - self.method.process_weights_after_loading(layer) - - @mock.patch("torch.matmul") - @mock.patch("torch.version") - def test_apply_with_bias_is_cann_8_3(self, mock_version, mock_npu_matmul): - layer = mock.MagicMock() - layer.weight = torch.randn(128, 256) - - x = torch.randn(32, 128) - bias = torch.randn(256) - - expected_y_output = torch.randn(32, 256) - mock_npu_matmul.return_value = expected_y_output - - mock_version.cann = "8.3.RC1" - output = self.method.apply(layer, x, bias) - - expected_y_output += bias - self.assertTrue(torch.equal(output, expected_y_output)) - - @mock.patch("torch.matmul") - @mock.patch("torch.version") - def test_apply_without_bias_is_cann_8_3(self, mock_version, - mock_npu_matmul): - layer = mock.MagicMock() - layer.weight = torch.randn(128, 256) - - x = torch.randn(32, 128) - - expected_y_output = torch.randn(32, 256) - mock_npu_matmul.return_value = expected_y_output - - mock_version.cann = "8.3.RC1" - output = self.method.apply(layer, x) - - self.assertTrue(torch.equal(output, expected_y_output)) - - @mock.patch("torch.nn.functional.linear") - @mock.patch("torch.version") - def test_apply_not_cann_8_3(self, mock_version, mock_npu_linear): - layer = mock.MagicMock() - layer.weight = torch.randn(128, 256) - - x = torch.randn(32, 128) - - expected_y_output = torch.randn(32, 256) - mock_npu_linear.return_value = expected_y_output - - mock_version.cann = "8.2.RC1" - output = self.method.apply(layer, x) - - self.assertTrue(torch.equal(output, expected_y_output)) - - class TestAscendRowParallelLinear(BaseLinearTest): def test_mlp_optimize(self): diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index 07f546b..fa5d13e 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -4,10 +4,10 @@ import torch from vllm.attention.layer import Attention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig -from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) from tests.ut.base import TestBase -from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod, AscendQuantConfig) from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD @@ -79,7 +79,7 @@ class TestAscendQuantConfig(TestBase): 'is_layer_skipped_ascend', return_value=True): method = self.ascend_config.get_quant_method(linear_layer, ".attn") - self.assertIsInstance(method, AscendUnquantizedLinearMethod) + self.assertIsInstance(method, UnquantizedLinearMethod) # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 9b472a7..8ffce39 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -36,36 +36,12 @@ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dense_optim_enable, - matmul_allreduce_enable, mlp_tp_enable, - oproj_tp_enable) +from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable, + mlp_tp_enable, oproj_tp_enable) _HCOMM_INFO = None -class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): - """Linear method without quantization.""" - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - if torch.version.cann.startswith("8.3"): - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_NZ) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if torch.version.cann.startswith("8.3"): - if bias is None: - return torch.matmul(x, layer.weight) - else: - return torch.matmul(x, layer.weight) + bias - else: - return torch.nn.functional.linear(x, layer.weight, bias) - - class AscendColumnParallelLinear(ColumnParallelLinear): """Linear layer with column parallelism. @@ -642,7 +618,7 @@ class AscendLinearBase(LinearBase): self.prefix = prefix if quant_config is None: self.quant_method: Optional[ - QuantizeMethodBase] = AscendUnquantizedLinearMethod() + QuantizeMethodBase] = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index fb644a1..6124fcb 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -23,7 +23,8 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - RowParallelLinear) + RowParallelLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import \ register_quantization_config from vllm.model_executor.layers.quantization.base_config import ( @@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, get_otp_group) from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, oproj_tp_enable) @@ -95,7 +95,7 @@ class AscendQuantConfig(QuantizationConfig): if isinstance(layer, LinearBase): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): - return AscendUnquantizedLinearMethod() + return UnquantizedLinearMethod() return AscendLinearMethod(self, prefix, self.packed_modules_mapping) elif isinstance(layer, Attention) and \