Files
xc-llm-ascend/tests/ut/ops/test_linear.py
wangxiyuan f10acddb78 drop ascend scheduler (#4498)
The Ascend scheduler was originally added for the non-chunked-prefill case,
because the NPU ops didn't work well with chunked prefill.

Now that the ops work better with chunked prefill, it's time to remove the
Ascend scheduler and use the vLLM default scheduler.

- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-11-29 16:18:34 +08:00

148 lines
4.7 KiB
Python

import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm import config
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
    """Shared fixture for the Ascend linear-layer tests.

    ``setUp`` builds one mocked communication group (``world_size`` 2,
    ``rank_in_group`` 0), installs it as ``parallel_state._MLP_TP`` and
    ``parallel_state._OTP``, and patches the group/config lookups listed
    below. Every patch is undone automatically when the test finishes.
    """

    def setUp(self):
        """Create the mocks and start all patches for a single test."""
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.oproj_tensor_parallel_size = 2

        self.patches = [
            # The module-level group globals are patched (not assigned) so
            # their previous values are restored once the test is over
            # instead of leaking the mock into unrelated tests.
            patch.object(parallel_state,
                         "_MLP_TP",
                         self.mock_group,
                         create=True),
            patch.object(parallel_state,
                         "_OTP",
                         self.mock_group,
                         create=True),
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch(
                "vllm.distributed.parallel_state.get_tp_group",
                return_value=self.mock_group,
            ),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True),
        ]
        for p in self.patches:
            p.start()
            # Register the stop immediately after each start: cleanups run
            # even when a later start() (or the rest of setUp) raises,
            # whereas tearDown() would be skipped in that case.
            self.addCleanup(p.stop)
class TestAscendUnquantizedLinearMethod(TestBase):
    """Checks when ``process_weights_after_loading`` calls the format cast."""

    def setUp(self):
        """Build the method under test and a mocked layer.

        The mocked layer's ``weight.data.dtype`` reports ``torch.float16``.
        """
        self.method = AscendUnquantizedLinearMethod()
        self.layer = mock.MagicMock()
        # A PropertyMock has to be attached to the mock's type, not to the
        # mock instance, for attribute access to go through it.
        weight_data_type = type(self.layer.weight.data)
        weight_data_type.dtype = mock.PropertyMock(
            return_value=torch.float16)

    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
    @mock.patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_enable_nz(self, format_cast_mock,
                                                     is_enable_nz_mock):
        """``npu_format_cast`` is called once when ``is_enable_nz`` is 1."""
        is_enable_nz_mock.return_value = 1

        self.method.process_weights_after_loading(self.layer)

        format_cast_mock.assert_called_once()

    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
    @mock.patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_disable_nz(self, format_cast_mock,
                                                      is_enable_nz_mock):
        """``npu_format_cast`` is never called when ``is_enable_nz`` is 0."""
        is_enable_nz_mock.return_value = 0

        self.method.process_weights_after_loading(self.layer)

        format_cast_mock.assert_not_called()
class TestAscendRowParallelLinear(BaseLinearTest):
    """Checks the communication group ``AscendRowParallelLinear`` selects."""

    def test_mlp_optimize(self):
        """A ``down_proj`` layer uses ``parallel_state._MLP_TP``."""
        # patch.dict restores os.environ on exit, so the flag cannot leak
        # into tests that run afterwards.
        with patch.dict(os.environ,
                        {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="down_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)

            # Smoke-run the forward pass; only "does not raise" is checked.
            input_tensor = torch.randn(16, 8)
            linear(input_tensor)

    def test_oproj_tp(self):
        """An ``o_proj`` layer uses ``parallel_state._OTP``."""
        mocked_ascend_config = MagicMock()
        mocked_ascend_config.oproj_tensor_parallel_size = 2

        # Swap the module-level config globals through patch.object so the
        # previous values are restored after the test instead of staying
        # overwritten for the rest of the test session.
        global_overrides = (
            (config, "_current_vllm_config", MagicMock()),
            (ascend_config, "_ASCEND_CONFIG", mocked_ascend_config),
        )
        for module, attr_name, replacement in global_overrides:
            patcher = patch.object(module,
                                   attr_name,
                                   replacement,
                                   create=True)
            patcher.start()
            self.addCleanup(patcher.stop)

        linear = AscendRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="o_proj",
        )
        self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)

        # Smoke-run the forward pass; only "does not raise" is checked.
        input_tensor = torch.randn(16, 8)
        linear(input_tensor)
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
    """Checks the group ``AscendMergedColumnParallelLinear`` selects."""

    def test_merged_mlp_tp_init(self):
        """A ``gate_up_proj`` layer uses ``parallel_state._MLP_TP``."""
        # patch.dict restores os.environ on exit, so the flag cannot leak
        # into tests that run afterwards.
        with patch.dict(os.environ,
                        {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendMergedColumnParallelLinear(
                input_size=16,
                output_sizes=[8, 8],
                prefix="gate_up_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)
class TestAscendReplicatedLinear(BaseLinearTest):
    """Checks the quant method chosen by ``AscendReplicatedLinear``."""

    def test_init_disable_tp(self):
        """The layer gets an ``AscendUnquantizedLinearMethod``."""
        # NOTE(review): this test passes exactly the same arguments as
        # test_init_without_disable_tp. Presumably it was meant to pass
        # disable_tp=True -- confirm against the AscendReplicatedLinear
        # signature before changing it.
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only says "False is not true".
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)

    def test_init_without_disable_tp(self):
        """With default arguments the quant method is the unquantized one."""
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()