Files
xc-llm-ascend/tests/ut/ops/test_linear.py
Mengqing Cao cea0755b07 [1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vllm main, and this
PR is aligned with
17c540a993

1. Refactor deepseek to the latest code architecture as of
17c540a993
 
2. A number of fixes required by vllm changes
- Fix `AscendScheduler` `__post_init__`, caused by
https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import, caused by
https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by
https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` import, caused by
https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` import, caused by
https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by
https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by
https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by
https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because `get_input_positions_tensor` changed to
`get_mrope_input_positions`, caused by
https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by
https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by
https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by
https://github.com/vllm-project/vllm/pull/25807
- Fix structured output breakage introduced by
https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
2025-10-24 16:55:08 +08:00

162 lines
5.3 KiB
Python

import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm import config
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
    """Shared fixture for the Ascend linear-layer test cases.

    ``setUp`` builds one mock communication group (``world_size`` 2,
    ``rank_in_group`` 0) and installs it as the MLP-TP, OTP and TP group.
    It also patches ``get_ascend_config`` to return a mock whose
    ``oproj_tensor_parallel_size`` is 2, and forces ``mlp_tp_enable`` and
    ``oproj_tp_enable`` to return ``True``.

    Every patch is undone through ``addCleanup``, so module state is
    restored even when ``setUp`` itself fails part-way through (unittest
    does not call ``tearDown`` in that case).
    """

    def setUp(self):
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.oproj_tensor_parallel_size = 2

        self.patches = [
            # The module-level group handles are patched (rather than
            # assigned) so their previous values come back after each test.
            patch.object(parallel_state,
                         "_MLP_TP",
                         self.mock_group,
                         create=True),
            patch.object(parallel_state,
                         "_OTP",
                         self.mock_group,
                         create=True),
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch(
                "vllm.distributed.parallel_state.get_tp_group",
                return_value=self.mock_group,
            ),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True),
        ]
        for p in self.patches:
            p.start()
            # Register the undo immediately: cleanups run in LIFO order and
            # also run when a later ``start()`` in this loop raises.
            self.addCleanup(p.stop)
class TestAscendUnquantizedLinearMethod(TestBase):
    """Checks ``AscendUnquantizedLinearMethod.process_weights_after_loading``.

    Each test fakes ``torch.version.cann`` and the return value of
    ``is_enable_nz``, then inspects whether ``torch_npu.npu_format_cast``
    was invoked.
    """

    def setUp(self):
        self.method = AscendUnquantizedLinearMethod()
        self.layer = mock.MagicMock()
        # A PropertyMock has to be attached to the mock's type, not to the
        # instance, for attribute access to go through it.
        weight_data_cls = type(self.layer.weight.data)
        weight_data_cls.dtype = mock.PropertyMock(return_value=torch.float16)

    def test_process_weights_after_loading_is_8_3_enable_nz(self):
        """CANN "8.3.RC1" with NZ enabled: the format cast runs once."""
        with mock.patch("torch.version") as fake_version, \
                mock.patch("torch_npu.npu_format_cast") as fake_cast, \
                mock.patch("vllm_ascend.ops.linear.is_enable_nz") as fake_nz:
            fake_version.cann = "8.3.RC1"
            fake_nz.return_value = 1

            self.method.process_weights_after_loading(self.layer)

            fake_cast.assert_called_once()

    def test_process_weights_after_loading_is_8_3_disable_nz(self):
        """CANN "8.3.RC1" with NZ disabled: the format cast is skipped."""
        with mock.patch("torch.version") as fake_version, \
                mock.patch("torch_npu.npu_format_cast") as fake_cast, \
                mock.patch("vllm_ascend.ops.linear.is_enable_nz") as fake_nz:
            fake_version.cann = "8.3.RC1"
            fake_nz.return_value = 0

            self.method.process_weights_after_loading(self.layer)

            fake_cast.assert_not_called()

    def test_process_weights_after_loading_not_8_3(self):
        """CANN "8.2.RC1": the call must simply complete without raising."""
        with mock.patch("torch.version") as fake_version, \
                mock.patch("vllm_ascend.ops.linear.is_enable_nz") as fake_nz:
            fake_version.cann = "8.2.RC1"
            fake_nz.return_value = 1

            # Should not raise exception
            self.method.process_weights_after_loading(self.layer)
class TestAscendRowParallelLinear(BaseLinearTest):
    """Checks which communication group ``AscendRowParallelLinear`` picks.

    Environment variables and module-level config globals touched by these
    tests are scoped with ``patch`` context managers so they are restored
    when each test finishes and cannot leak into other tests.
    """

    def test_mlp_optimize(self):
        """A ``down_proj`` layer uses the MLP-TP group when MLP optimize is on."""
        with mock.patch.dict(os.environ,
                             {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="down_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)

            # Smoke-check the forward pass; only "does not raise" is tested.
            input_tensor = torch.randn(16, 8)
            linear(input_tensor)

    def test_oproj_tp(self):
        """An ``o_proj`` layer uses the OTP group."""
        fake_ascend_config = MagicMock()
        fake_ascend_config.oproj_tensor_parallel_size = 2
        fake_ascend_config.ascend_scheduler_config.enabled = False

        # create=True keeps the original "plain assignment" semantics even
        # if the attribute is not currently defined on the module.
        with patch.object(config,
                          "_current_vllm_config",
                          MagicMock(),
                          create=True), \
                patch.object(ascend_config,
                             "_ASCEND_CONFIG",
                             fake_ascend_config,
                             create=True):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="o_proj",
            )
            self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)

            # Smoke-check the forward pass; only "does not raise" is tested.
            input_tensor = torch.randn(16, 8)
            linear(input_tensor)
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
    """Checks the group chosen by ``AscendMergedColumnParallelLinear``."""

    def test_merged_mlp_tp_init(self):
        """A ``gate_up_proj`` layer uses the MLP-TP group."""
        # patch.dict restores os.environ on exit, so the flag does not leak
        # into tests that run afterwards.
        with mock.patch.dict(os.environ,
                             {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendMergedColumnParallelLinear(
                input_size=16,
                output_sizes=[8, 8],
                prefix="gate_up_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)
class TestAscendReplicatedLinear(BaseLinearTest):
    """Checks the default quant method of ``AscendReplicatedLinear``."""

    # NOTE(review): both tests below construct the layer with identical
    # arguments, so they currently exercise the same path. The names suggest
    # one of them was meant to pass a TP-disabling argument -- confirm
    # against the ``AscendReplicatedLinear`` signature before changing.

    def test_init_disable_tp(self):
        """The layer falls back to ``AscendUnquantizedLinearMethod``."""
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        # assertIsInstance reports the actual type when the check fails.
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)

    def test_init_without_disable_tp(self):
        """The layer falls back to ``AscendUnquantizedLinearMethod``."""
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        # assertIsInstance reports the actual type when the check fails.
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()