Files
xc-llm-ascend/tests/ut/ops/test_linear.py
rjg-lyh 74903af460 [v0.11.0][refactor] refactor SequenceRowParallelOp forward (#3654)
### What this PR does / why we need it?
This PR refactors SequenceRowParallelOp forward. In order to further
expand the operator inclusion scope in dynamic judgment scenarios, this
PR customizes the entire matmul computation and communication as a
custom operator masking. With this refactor, it will support directly
writing code such as common operation fusion into the
SequenceRowParallelOp class's member function matmul_and_reduce, without
the need to register more redundant custom masking operators.

### How was this patch tested?
CI passed with new added/existing test.

Signed-off-by: rjg-lyh <1318825571@qq.com>
2025-10-23 14:45:49 +08:00

161 lines
5.2 KiB
Python

import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm import config
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
def setUp(self):
self.mock_group = mock.MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
parallel_state._MLP_TP = self.mock_group
parallel_state._OTP = self.mock_group
self.mock_ascend_config = MagicMock()
self.mock_ascend_config.oproj_tensor_parallel_size = 2
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch("vllm_ascend.distributed.parallel_state.get_otp_group",
return_value=self.mock_group),
patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.linear_op.get_tp_group",
return_value=self.mock_group),
patch(
"vllm.distributed.parallel_state.get_tp_group",
return_value=self.mock_group,
),
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
]
for p in self.patches:
p.start()
def tearDown(self):
for p in self.patches:
p.stop()
class TestAscendUnquantizedLinearMethod(TestBase):
def setUp(self):
self.method = AscendUnquantizedLinearMethod()
self.layer = mock.MagicMock()
mock_dtype = mock.PropertyMock(return_value=torch.float16)
type(self.layer.weight.data).dtype = mock_dtype
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_enable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_disable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 0
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_not_called()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch.version")
def test_process_weights_after_loading_not_8_3(self, mock_version,
mock_is_nz):
mock_version.cann = "8.2.RC1"
mock_is_nz.return_value = 1
# Should not raise exception
self.method.process_weights_after_loading(self.layer)
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
)
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
input_tensor = torch.randn(16, 8)
linear(input_tensor)
def test_oproj_tp(self):
config._current_vllm_config = MagicMock()
ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="o_proj",
)
self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)
input_tensor = torch.randn(16, 8)
linear(input_tensor)
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
def test_merged_mlp_tp_init(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
linear = AscendMergedColumnParallelLinear(
input_size=16,
output_sizes=[8, 8],
prefix="gate_up_proj",
)
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
class TestAscendReplicatedLinear(BaseLinearTest):
def test_init_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
def test_init_without_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
if __name__ == '__main__':
unittest.main()