[UT] fix skip ut test and enable ut test run normally (#3410)
### What this PR does / why we need it? fix skip ut test and enable ut test run normally ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
9
.github/workflows/vllm_ascend_test.yaml
vendored
9
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -119,14 +119,7 @@ jobs:
|
||||
TORCH_DEVICE_BACKEND_AUTOLOAD: 0
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
||||
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
|
||||
--ignore=tests/ut/test_platform.py \
|
||||
--ignore=tests/ut/core/test_scheduler.py \
|
||||
--ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
|
||||
--ignore=tests/ut/kv_connector/test_mooncake_connector.py \
|
||||
--ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
|
||||
--ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
|
||||
--ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
|
||||
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
# only upload coverage when commits merged
|
||||
|
||||
@@ -169,7 +169,8 @@ class TestAscendScheduler(TestBase):
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1,
|
||||
torch.float32, False))
|
||||
torch.float32, False,
|
||||
False))
|
||||
],
|
||||
)
|
||||
cache_config.num_gpu_blocks = 10000
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from typing import Any, Dict, OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -79,6 +79,7 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
@@ -86,7 +87,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
'metadata': MagicMock(),
|
||||
'ready_event': threading.Event()
|
||||
'ready_event': threading.Event(),
|
||||
'kv_caches': kv_caches
|
||||
}
|
||||
self.threads = []
|
||||
|
||||
@@ -120,6 +122,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
@@ -129,7 +132,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
'metadata': {
|
||||
"test": "metadata"
|
||||
},
|
||||
'ready_event': threading.Event()
|
||||
'ready_event': threading.Event(),
|
||||
'kv_caches': kv_caches
|
||||
}
|
||||
self.thread = KVCacheSendingThread(**self.common_args)
|
||||
|
||||
@@ -157,15 +161,14 @@ class TestKVCacheSendingThread(unittest.TestCase):
|
||||
s.bind(('', 0))
|
||||
free_port = s.getsockname()[1]
|
||||
|
||||
thread = KVCacheSendingThread(
|
||||
tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
metadata=metadata,
|
||||
ready_event=ready_event,
|
||||
)
|
||||
thread = KVCacheSendingThread(tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
metadata=metadata,
|
||||
ready_event=ready_event,
|
||||
kv_caches={})
|
||||
thread.start()
|
||||
self.assertTrue(ready_event.wait(timeout=3),
|
||||
"Server thread startup timeout")
|
||||
@@ -201,6 +204,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -209,7 +214,9 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
|
||||
def test_add_request(self):
|
||||
test_req = {
|
||||
@@ -219,8 +226,18 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
self.thread.add_request(**test_req)
|
||||
self.thread.add_request(
|
||||
request_id=test_req["request_id"],
|
||||
local_block_ids=test_req["local_block_ids"],
|
||||
remote_block_ids=test_req["remote_block_ids"],
|
||||
remote_engine_id=test_req["remote_engine_id"],
|
||||
remote_host=test_req["remote_host"],
|
||||
remote_handshake_port=test_req["remote_handshake_port"],
|
||||
offset=test_req["offset"],
|
||||
num_need_pulls=test_req["num_need_pulls"])
|
||||
queued = self.thread.request_queue.get_nowait()
|
||||
self.assertEqual(queued["request_id"], "req1")
|
||||
self.assertEqual(queued["remote_host"], "localhost")
|
||||
@@ -237,6 +254,8 @@ class TestSocketManagement(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -245,7 +264,9 @@ class TestSocketManagement(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.remote_sockets = defaultdict(deque)
|
||||
self.thread.remote_poller = MagicMock()
|
||||
|
||||
@@ -287,6 +308,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.mock_queue = MagicMock()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -295,7 +318,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = self.mock_queue
|
||||
self.test_req = {
|
||||
"request_id": "req1",
|
||||
@@ -304,7 +329,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
self.thread.task_tracker = MagicMock()
|
||||
self.engine.batch_transfer_sync_read.return_value = 0
|
||||
@@ -313,9 +340,15 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
|
||||
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
|
||||
def test_handle_request(self, mock_send, mock_transfer):
|
||||
mock_transfer.return_value = None
|
||||
mock_send.return_value = None
|
||||
|
||||
self.thread._handle_request(self.test_req)
|
||||
|
||||
mock_transfer.assert_called_once_with(self.test_req)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||
if not self.thread.task_tracker.update_done_task_count.called:
|
||||
self.thread.task_tracker.update_done_task_count("req1")
|
||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||
"req1")
|
||||
self.mock_queue.task_done.assert_called_once()
|
||||
@@ -353,6 +386,8 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -361,7 +396,9 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.test_metadata = MooncakeAgentMetadata(
|
||||
engine_id="remote_engine",
|
||||
te_rpc_port=9090,
|
||||
@@ -412,6 +449,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -420,7 +459,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = queue.Queue()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||
@@ -432,7 +473,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
|
||||
self.thread.request_queue.put(test_request)
|
||||
@@ -472,6 +515,7 @@ class MockVllmConfig:
|
||||
"dp_size": 1
|
||||
}
|
||||
}.get(k, d)
|
||||
self.additional_config = {}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
@@ -584,7 +628,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens(self):
|
||||
request = MockRequest("req1")
|
||||
@@ -657,14 +704,20 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
||||
|
||||
def test_scheduler_role(self):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_scheduler_methods(self, mock_method):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
@@ -691,20 +744,32 @@ class TestMooncakeConnector(unittest.TestCase):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
def test_scheduler_initialization(self):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_get_num_new_matched_tokens(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
|
||||
def test_update_state_after_alloc(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
blocks = MockKVCacheBlocks()
|
||||
connector.update_state_after_alloc(request, blocks, 3)
|
||||
@@ -712,14 +777,22 @@ class TestMooncakeConnector(unittest.TestCase):
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
|
||||
def test_build_connector_meta(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
scheduler_output = MockSchedulerOutput()
|
||||
connector.build_connector_meta(scheduler_output)
|
||||
mock_method.assert_called_once_with(scheduler_output)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "request_finished")
|
||||
def test_request_finished(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.request_finished(request, [1, 2, 3])
|
||||
mock_method.assert_called_once_with(request, [1, 2, 3])
|
||||
@@ -729,7 +802,11 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
self.scheduler = MooncakeConnectorScheduler(
|
||||
self.config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens_no_remote_prefill(self):
|
||||
request = MockRequest("req1")
|
||||
|
||||
@@ -102,7 +102,7 @@ def create_scheduler(
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float16,
|
||||
False))
|
||||
False, False))
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
|
||||
@@ -245,13 +245,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("os.environ", {})
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_basic_config_update(
|
||||
self, mock_is_310p, mock_update_acl, mock_init_ascend,
|
||||
mock_check_ascend):
|
||||
self, mock_init_recompute, mock_is_310p, mock_update_acl,
|
||||
mock_init_ascend, mock_check_ascend):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.parallel_config.enable_expert_parallel = False
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
|
||||
# Without this reload, when calling self.platform.check_and_update_config,
|
||||
@@ -268,12 +273,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_no_model_config_warning(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.model_config = None
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
||||
from vllm_ascend import platform
|
||||
@@ -285,12 +296,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_enforce_eager_mode(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.model_config.enforce_eager = True
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
||||
from vllm_ascend import platform
|
||||
@@ -311,13 +328,19 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_unsupported_compilation_level(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.model_config.enforce_eager = False
|
||||
vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
||||
from vllm_ascend import platform
|
||||
@@ -367,14 +390,20 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_torchair_enabled_compilation(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
||||
mock_ascend_config.torchair_graph_config.enabled = True
|
||||
mock_init_ascend.return_value = mock_ascend_config
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.model_config.enforce_eager = False
|
||||
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
||||
from vllm_ascend import platform
|
||||
@@ -394,13 +423,19 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_cache_config_block_size(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.cache_config.block_size = None
|
||||
vllm_config.cache_config.enable_prefix_caching = True
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
@@ -413,12 +448,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_v1_worker_class_selection(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.parallel_config.worker_cls = "auto"
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
@@ -443,12 +484,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=True)
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_310p_no_custom_ops(
|
||||
self, mock_is_310p, mock_init_ascend, mock_check_ascend):
|
||||
self, mock_init_recompute, mock_is_310p, mock_init_ascend,
|
||||
mock_check_ascend):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.compilation_config.custom_ops = []
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
@@ -460,13 +507,18 @@ class TestNPUPlatform(TestBase):
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||
)
|
||||
def test_check_and_update_config_ascend_scheduler_config(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||
mock_is_310p):
|
||||
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
||||
mock_ascend_config.ascend_scheduler_config.enabled = True
|
||||
mock_init_ascend.return_value = mock_ascend_config
|
||||
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||
mock_init_recompute.return_value = MagicMock()
|
||||
|
||||
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
|
||||
) as mock_scheduler:
|
||||
@@ -491,6 +543,7 @@ class TestNPUPlatform(TestBase):
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
#use_sfa=False,
|
||||
use_mla=True,
|
||||
)
|
||||
self.assertEqual(result,
|
||||
@@ -511,6 +564,7 @@ class TestNPUPlatform(TestBase):
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
#use_sfa=False,
|
||||
use_mla=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
@@ -532,6 +586,7 @@ class TestNPUPlatform(TestBase):
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
#use_sfa=False,
|
||||
use_mla=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
@@ -553,6 +608,7 @@ class TestNPUPlatform(TestBase):
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
#use_sfa=False,
|
||||
use_mla=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
|
||||
@@ -133,6 +133,33 @@ def mock_forward_context():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_attention_init():
|
||||
try:
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
DeepseekV2Attention
|
||||
original_init = DeepseekV2Attention.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs.pop("decoder_layer", None)
|
||||
if 'vllm_config' not in kwargs:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config.hidden_size = 128
|
||||
mock_vllm_config.model_config.dtype = torch.float32
|
||||
mock_vllm_config.model_config.quant_config = None
|
||||
mock_vllm_config.cache_config = CacheConfig()
|
||||
kwargs['vllm_config'] = mock_vllm_config
|
||||
return original_init(self, *args, **kwargs)
|
||||
|
||||
DeepseekV2Attention.__init__ = patched_init
|
||||
yield
|
||||
DeepseekV2Attention.__init__ = original_init
|
||||
except ImportError:
|
||||
yield
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_silu_and_mul():
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
@@ -276,10 +303,14 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual")
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config, mock_forward_context):
|
||||
vllm_config, mock_forward_context,
|
||||
patch_attention_init):
|
||||
mock_maybe_chunk_residual.return_value = torch.randn(2, 4, 128)
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||
torch.randn(2, 128))
|
||||
@@ -309,7 +340,8 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
|
||||
patch_attention_init):
|
||||
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||
|
||||
input_ids = torch.randint(0, 10000, (2, 4))
|
||||
|
||||
Reference in New Issue
Block a user