[Feat] Unquantized Linear to nz and control all nz-cast (#3356)

### What this PR does / why we need it?
Currently, when executing to the Linear layer of models in vLLM-Ascend,
the weights format is ND in unquantized case and skipped ascend case.
This PR supplements the execution logic for Linear layer. We use a new
global variable: VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and
CANN version is 8.3, the weights of the Linear layer will be converted
to FRACTAL_NZ, in both unquantized case and skipped ascend case. We also
use VLLM_ASCEND_ENABLE_NZ to control the existing NZ conversion, such as
w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Add a new global variable VLLM_ASCEND_ENABLE_NZ. If you want to use NZ
format, you should set VLLM_ASCEND_ENABLE_NZ=1.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-14 17:39:26 +08:00
committed by GitHub
parent 5c45c227dc
commit 07e39620ea
22 changed files with 413 additions and 49 deletions

View File

@@ -5,10 +5,13 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
@@ -49,6 +52,47 @@ class BaseLinearTest(unittest.TestCase):
p.stop()
class TestAscendUnquantizedLinearMethod(TestBase):
def setUp(self):
self.method = AscendUnquantizedLinearMethod()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_enable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_called_once()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_disable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 0
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_not_called()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch.version")
def test_process_weights_after_loading_not_8_3(self, mock_version,
mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.2.RC1"
mock_is_nz.return_value = 1
# Should not raise exception
self.method.process_weights_after_loading(layer)
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
@@ -92,5 +136,24 @@ class TestAscendMergedColumnParallelLinear(BaseLinearTest):
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
class TestAscendReplicatedLinear(BaseLinearTest):
def test_init_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
def test_init_without_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
if __name__ == '__main__':
unittest.main()