Files
xc-llm-ascend/tests/ut/ops/test_linear.py
realliujiaxu af2a886814 refactor linear (#2867)
### What this PR does / why we need it?
The current linear.py has the following issues:

- There is redundant conditional logic in the `comm_group` and `forward`
selection for classes such as `AscendMergedColumnParallelLinear`.

- The `comm_group` selection logic is inconsistent across
`AscendMergedColumnParallelLinear`, `AscendColumnParallelLinear`, and
`AscendQKVParallelLinear`.

To address these two issues, this PR encapsulates `comm_group` and
`forward` into classes and extracts the class-selection logic into
common functions. To add a custom communication group or forward method
in the future, it will only be necessary to extend
`CustomColumnParallelOp` or `CustomRowParallelOp` and add the
corresponding selection logic.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?


- vLLM version: v0.10.2
- vLLM main:
dd39baf717

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: weijinqian0 <weijinqian@huawei.com>
2025-09-18 14:09:19 +08:00

93 lines
2.8 KiB
Python

import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
class BaseLinearTest(unittest.TestCase):
    """Shared fixture that replaces distributed/config state with mocks.

    A single 2-rank mock group (this process is rank 0) stands in for the TP,
    MLP-TP and o_proj-TP groups. The Ascend config getter returns a mock with
    ``oproj_tensor_parallel_size = 2``, and the ``mlp_tp_enable`` /
    ``oproj_tp_enable`` switches are forced to ``True``.

    Every patch, including the module-level group handles in
    ``parallel_state``, is undone through ``addCleanup``. This happens even if
    ``setUp`` fails part-way through, so no mocked global outlives the test
    that installed it.
    """

    def setUp(self):
        super().setUp()
        # Fake communication group: 2 ranks, current process is rank 0.
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.oproj_tensor_parallel_size = 2

        self.patches = [
            # Tests compare ``custom_op.comm_group`` against these module
            # globals. Patch them instead of assigning them so the original
            # values are restored after each test.
            patch.object(parallel_state, "_MLP_TP", self.mock_group),
            patch.object(parallel_state, "_OTP", self.mock_group),
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True),
        ]
        for p in self.patches:
            p.start()
            # Register cleanup right after each successful start. If a later
            # start() raises, the patches already applied are still stopped.
            self.addCleanup(p.stop)
class TestAscendRowParallelLinear(BaseLinearTest):
    """Check which comm group ``AscendRowParallelLinear`` uses for a prefix.

    ``down_proj`` is expected to use the MLP-TP group, and ``o_proj`` the
    o_proj-TP group. Each test also runs a forward pass as a smoke check.
    """

    def test_mlp_optimize(self):
        # Scope the env var to this test. A bare ``os.environ[...] = ...``
        # would stay set for every test that runs later in the process.
        with patch.dict(os.environ,
                        {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="down_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)
            input_tensor = torch.randn(16, 8)
            # Smoke test only: the forward pass must not raise.
            linear(input_tensor)

    def test_oproj_tp(self):
        oproj_config = MagicMock()
        oproj_config.oproj_tensor_parallel_size = 2
        # Install the fake config for this test only. Assigning the module
        # global directly would leave it in place for the rest of the process.
        with patch.object(ascend_config, "_ASCEND_CONFIG", oproj_config):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="o_proj",
            )
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._OTP)
            input_tensor = torch.randn(16, 8)
            # Smoke test only: the forward pass must not raise.
            linear(input_tensor)
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
    """Check comm-group selection for ``AscendMergedColumnParallelLinear``."""

    def test_merged_mlp_tp_init(self):
        # Scope the env var to this test so it does not leak into later tests.
        with patch.dict(os.environ,
                        {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendMergedColumnParallelLinear(
                input_size=16,
                output_sizes=[8, 8],
                prefix="gate_up_proj",
            )
            # A ``gate_up_proj`` layer is expected to use the MLP-TP group.
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)
if __name__ == '__main__':
    # Allow running this test module directly with the stdlib runner.
    unittest.main()