xc-llm-ascend/tests/ut/ops/test_linear.py
lidenghui1110 d65fb194d9 [Feat] Add custom Embedding tensor model parallel (#2616)
Similar to #2309, this PR introduces Embedding tensor model parallel to
reduce memory consumption. It supports both eager mode and graph mode.

This PR also refactors the module tensor parallel configurations introduced
in #2309, #2167, and #2120, merging them all into `finegrained_tp_config`
under `additional_config`. The merged options are:
- `lmhead_tensor_parallel_size`
- `oproj_tensor_parallel_size`
- `embedding_tensor_parallel_size`
- `mlp_tensor_parallel_size`
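
As a minimal sketch of the merged layout, using the offline `LLM` entry
point: the model name and the parallel sizes below are placeholders, only
the key names come from this PR.

```python
from vllm import LLM

# Hypothetical values; only the key names are defined by this PR.
llm = LLM(
    model="some/model",  # placeholder
    tensor_parallel_size=8,
    additional_config={
        "finegrained_tp_config": {
            "lmhead_tensor_parallel_size": 2,
            "oproj_tensor_parallel_size": 2,
            "embedding_tensor_parallel_size": 2,
            "mlp_tensor_parallel_size": 2,
        },
    },
)
```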

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
2025-12-12 14:41:20 +08:00


import unittest
from unittest import mock
from unittest.mock import MagicMock, patch

import torch
from vllm import config

from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
                                    AscendReplicatedLinear,
                                    AscendRowParallelLinear,
                                    AscendUnquantizedLinearMethod)
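
# These unit tests exercise the Ascend linear-layer overrides in isolation:
# the distributed groups (_MLP_TP, _OTP) and the ascend config are replaced
# with mocks, so no real NPU or communication backend is required.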


class BaseLinearTest(unittest.TestCase):

    def setUp(self):
        # Fake a 2-way communication group shared by the MLP-TP and
        # oproj-TP paths.
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        parallel_state._MLP_TP = self.mock_group
        parallel_state._OTP = self.mock_group

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
        self.mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2

        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch("vllm.distributed.parallel_state.get_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True),
        ]
        for p in self.patches:
            p.start()

    def tearDown(self):
        for p in self.patches:
            p.stop()


class TestAscendUnquantizedLinearMethod(TestBase):

    def setUp(self):
        self.method = AscendUnquantizedLinearMethod()
        self.layer = mock.MagicMock()
        mock_dtype = mock.PropertyMock(return_value=torch.float16)
        type(self.layer.weight.data).dtype = mock_dtype

    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
    @mock.patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_enable_nz(self, mock_format_cast,
                                                     mock_is_nz):
        # With NZ enabled, the weight should be cast to the NZ format.
        mock_is_nz.return_value = True
        self.method.process_weights_after_loading(self.layer)
        mock_format_cast.assert_called_once()

    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
    @mock.patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_disable_nz(self, mock_format_cast,
                                                      mock_is_nz):
        # With NZ disabled, the weight must be left in its original format.
        mock_is_nz.return_value = False
        self.method.process_weights_after_loading(self.layer)
        mock_format_cast.assert_not_called()


class TestAscendRowParallelLinear(BaseLinearTest):

    def test_mlp_optimize(self):
        ascend_config._ASCEND_CONFIG = MagicMock()
        ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
        ascend_config._ASCEND_CONFIG.finegrained_tp_config.mlp_tensor_parallel_size = 2
        ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False

        # A "down_proj" prefix should route the layer onto the MLP TP group.
        linear = AscendRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="down_proj",
        )
        self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)

        input_tensor = torch.randn(16, 8)
        linear(input_tensor)

    def test_oproj_tp(self):
        config._current_vllm_config = MagicMock()
        ascend_config._ASCEND_CONFIG = MagicMock()
        ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
        ascend_config._ASCEND_CONFIG.finegrained_tp_config.oproj_tensor_parallel_size = 2
        ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False

        # An "o_proj" prefix should route the layer onto the oproj TP group.
        linear = AscendRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="o_proj",
        )
        self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)

        input_tensor = torch.randn(16, 8)
        linear(input_tensor)


class TestAscendMergedColumnParallelLinear(BaseLinearTest):

    def test_merged_mlp_tp_init(self):
        ascend_config._ASCEND_CONFIG = MagicMock()
        ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
        ascend_config._ASCEND_CONFIG.finegrained_tp_config.mlp_tensor_parallel_size = 2
        ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False

        # The merged gate/up projection should also land on the MLP TP group.
        linear = AscendMergedColumnParallelLinear(
            input_size=16,
            output_sizes=[8, 8],
            prefix="gate_up_proj",
        )
        self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)


class TestAscendReplicatedLinear(BaseLinearTest):

    def test_init_disable_tp(self):
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
            disable_tp=True,  # distinguishes this case from the default below
        )
        self.assertTrue(
            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))

    def test_init_without_disable_tp(self):
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        self.assertTrue(
            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))


if __name__ == '__main__':
    unittest.main()
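
# These tests are plain unittest cases, so they can also be run via pytest:
#   pytest tests/ut/ops/test_linear.py -v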