[Feat] Unquantized linear nz support (#2619)

### What this PR does / why we need it?
Currently, when executing to the Linear layer of the model in
vLLM-Ascend, the weights input format is ND in unquantized case and
skipped ascend case, which is slower than FRACTAL_NZ.
This PR supplements the execution logic for Linear layer. When
VLLM_ASCEND_ENABLE_MLP_OPTIMIZE=1 and CANN version is 8.3, the weights
of the Linear layer will be converted to FRACTAL_NZ, in both unquantized
case and skipped ascend case.

- vLLM version: main
- vLLM main:
267c80d31f

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-09-11 11:40:00 +08:00
committed by GitHub
parent 5691104249
commit 7b2ecc1e9a
4 changed files with 111 additions and 10 deletions

View File

@@ -5,11 +5,13 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
@@ -46,6 +48,81 @@ class BaseLinearTest(unittest.TestCase):
p.stop()
class TestAscendUnquantizedLinearMethod(TestBase):
def setUp(self):
self.method = AscendUnquantizedLinearMethod()
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_cann_8_3(self, mock_version,
mock_format_cast):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_called_once()
@mock.patch("torch.version")
def test_process_weights_after_loading_not_cann_8_3(self, mock_version):
layer = mock.MagicMock()
mock_version.cann = "8.2.RC1"
# Should not raise exception
self.method.process_weights_after_loading(layer)
@mock.patch("torch.matmul")
@mock.patch("torch.version")
def test_apply_with_bias_is_cann_8_3(self, mock_version, mock_npu_matmul):
layer = mock.MagicMock()
layer.weight = torch.randn(128, 256)
x = torch.randn(32, 128)
bias = torch.randn(256)
expected_y_output = torch.randn(32, 256)
mock_npu_matmul.return_value = expected_y_output
mock_version.cann = "8.3.RC1"
output = self.method.apply(layer, x, bias)
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@mock.patch("torch.matmul")
@mock.patch("torch.version")
def test_apply_without_bias_is_cann_8_3(self, mock_version,
mock_npu_matmul):
layer = mock.MagicMock()
layer.weight = torch.randn(128, 256)
x = torch.randn(32, 128)
expected_y_output = torch.randn(32, 256)
mock_npu_matmul.return_value = expected_y_output
mock_version.cann = "8.3.RC1"
output = self.method.apply(layer, x)
self.assertTrue(torch.equal(output, expected_y_output))
@mock.patch("torch.nn.functional.linear")
@mock.patch("torch.version")
def test_apply_not_cann_8_3(self, mock_version, mock_npu_linear):
layer = mock.MagicMock()
layer.weight = torch.randn(128, 256)
x = torch.randn(32, 128)
expected_y_output = torch.randn(32, 256)
mock_npu_linear.return_value = expected_y_output
mock_version.cann = "8.2.RC1"
output = self.method.apply(layer, x)
self.assertTrue(torch.equal(output, expected_y_output))
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):

View File

@@ -4,10 +4,10 @@ import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -79,7 +79,7 @@ class TestAscendQuantConfig(TestBase):
'is_layer_skipped_ascend',
return_value=True):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, UnquantizedLinearMethod)
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \

View File

@@ -36,12 +36,36 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
mlp_tp_enable, oproj_tp_enable)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dense_optim_enable,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
_HCOMM_INFO = None
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
"""Linear method without quantization."""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if torch.version.cann.startswith("8.3"):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if torch.version.cann.startswith("8.3"):
if bias is None:
return torch.matmul(x, layer.weight)
else:
return torch.matmul(x, layer.weight) + bias
else:
return torch.nn.functional.linear(x, layer.weight, bias)
class AscendColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
@@ -617,7 +641,7 @@ class AscendLinearBase(LinearBase):
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)

View File

@@ -23,8 +23,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
@@ -38,6 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)
@@ -92,7 +92,7 @@ class AscendQuantConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, Attention) and \