The Ascend scheduler was previously added for the non-chunked-prefill case, since the NPU ops did not work well with chunked prefill. Now that the ops work better with chunked prefill, it is time to remove the Ascend scheduler and use vLLM's default scheduler. - vLLM version: v0.11.2 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
import os
|
|
import unittest
|
|
from unittest import mock
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
from vllm import config
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend import ascend_config
|
|
from vllm_ascend.distributed import parallel_state
|
|
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
|
|
AscendReplicatedLinear,
|
|
AscendRowParallelLinear,
|
|
AscendUnquantizedLinearMethod)
|
|
|
|
|
|
class BaseLinearTest(unittest.TestCase):
    """Shared fixture for Ascend linear-layer tests.

    Installs a fake communication group as both the MLP-TP and OTP
    module-level groups, and patches the project helpers so that linear
    layers constructed in subclasses resolve the mocked groups.
    """

    def setUp(self):
        # One fake group serves every group lookup (world_size/rank values
        # are what the custom ops read when sharding).
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        # Install the fake group as the module-level parallel-state groups.
        parallel_state._MLP_TP = self.mock_group
        parallel_state._OTP = self.mock_group

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.oproj_tensor_parallel_size = 2

        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch(
                "vllm.distributed.parallel_state.get_tp_group",
                return_value=self.mock_group,
            ),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
        ]

        # Register each stop via addCleanup immediately after starting:
        # cleanups run even when setUp fails part-way, whereas tearDown is
        # skipped in that case (the previous setUp/tearDown pairing leaked
        # any patches already started before a failing start()).
        for p in self.patches:
            p.start()
            self.addCleanup(p.stop)
|
|
|
|
|
|
class TestAscendUnquantizedLinearMethod(TestBase):
    """Tests for AscendUnquantizedLinearMethod weight post-processing."""

    def setUp(self):
        self.method = AscendUnquantizedLinearMethod()
        self.layer = mock.MagicMock()
        # Make layer.weight.data.dtype report float16 through a property
        # mock attached to the mock's type.
        dtype_prop = mock.PropertyMock(return_value=torch.float16)
        type(self.layer.weight.data).dtype = dtype_prop

    def test_process_weights_after_loading_enable_nz(self):
        with mock.patch("torch_npu.npu_format_cast") as format_cast, \
                mock.patch("vllm_ascend.ops.linear.is_enable_nz",
                           return_value=1):
            self.method.process_weights_after_loading(self.layer)
            # NZ enabled: the weight must be cast to NZ format exactly once.
            format_cast.assert_called_once()

    def test_process_weights_after_loading_disable_nz(self):
        with mock.patch("torch_npu.npu_format_cast") as format_cast, \
                mock.patch("vllm_ascend.ops.linear.is_enable_nz",
                           return_value=0):
            self.method.process_weights_after_loading(self.layer)
            # NZ disabled: no format cast may happen.
            format_cast.assert_not_called()
|
|
|
|
|
|
class TestAscendRowParallelLinear(BaseLinearTest):
    """Tests that AscendRowParallelLinear picks the right comm group."""

    def test_mlp_optimize(self):
        # patch.dict restores the environment on exit; the previous direct
        # os.environ assignment was never undone and leaked the flag into
        # every test that ran afterwards in the same process.
        with mock.patch.dict(os.environ,
                             {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="down_proj",
            )
            # A "down_proj" layer with MLP optimize on must use the MLP
            # tensor-parallel group.
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)

            input_tensor = torch.randn(16, 8)
            linear(input_tensor)

    def test_oproj_tp(self):
        # Patch the module-level singletons instead of assigning to them
        # directly, so the originals are restored and cannot poison later
        # tests in the same process.
        with mock.patch.object(config, "_current_vllm_config", MagicMock()), \
                mock.patch.object(ascend_config, "_ASCEND_CONFIG",
                                  MagicMock()) as mock_cfg:
            mock_cfg.oproj_tensor_parallel_size = 2

            linear = AscendRowParallelLinear(
                input_size=16,
                output_size=8,
                prefix="o_proj",
            )
            # An "o_proj" layer with oproj TP enabled must use the OTP group.
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._OTP)

            input_tensor = torch.randn(16, 8)
            linear(input_tensor)
|
|
|
|
|
|
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
    """Tests for AscendMergedColumnParallelLinear comm-group selection."""

    def test_merged_mlp_tp_init(self):
        # patch.dict restores the environment on exit; the previous direct
        # os.environ assignment leaked the flag into subsequent tests.
        with mock.patch.dict(os.environ,
                             {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}):
            linear = AscendMergedColumnParallelLinear(
                input_size=16,
                output_sizes=[8, 8],
                prefix="gate_up_proj",
            )
            # A "gate_up_proj" layer with MLP optimize on must use the MLP
            # tensor-parallel group.
            self.assertEqual(linear.custom_op.comm_group,
                             parallel_state._MLP_TP)
|
|
|
|
|
|
class TestAscendReplicatedLinear(BaseLinearTest):
    """Tests for AscendReplicatedLinear quant-method selection.

    NOTE(review): both tests construct the layer identically despite names
    suggesting one should toggle a disable_tp-style flag — confirm against
    AscendReplicatedLinear's constructor signature.
    """

    def test_init_disable_tp(self):
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only prints "False is not true".
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)

    def test_init_without_disable_tp(self):
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        self.assertIsInstance(linear.quant_method,
                              AscendUnquantizedLinearMethod)
|
|
|
|
|
|
# Allow running this test module directly (e.g. `python test_linear.py`).
if __name__ == '__main__':
    unittest.main()
|