[UT] fix skip ut test and enable ut test run normally (#3410)

### What this PR does / why we need it?

fix skip ut test and enable ut test run normally

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
zhangxinyuehfad
2025-10-20 16:30:57 +08:00
committed by GitHub
parent f8b52fe950
commit fdac146f71
6 changed files with 212 additions and 53 deletions

View File

@@ -119,14 +119,7 @@ jobs:
TORCH_DEVICE_BACKEND_AUTOLOAD: 0 TORCH_DEVICE_BACKEND_AUTOLOAD: 0
run: | run: |
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \ pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut
--ignore=tests/ut/test_platform.py \
--ignore=tests/ut/core/test_scheduler.py \
--ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
--ignore=tests/ut/kv_connector/test_mooncake_connector.py \
--ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
--ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
--ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
# only upload coverage when commits merged # only upload coverage when commits merged

View File

@@ -169,7 +169,8 @@ class TestAscendScheduler(TestBase):
kv_cache_groups=[ kv_cache_groups=[
KVCacheGroupSpec(['layer'], KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, FullAttentionSpec(block_size, 1, 1,
torch.float32, False)) torch.float32, False,
False))
], ],
) )
cache_config.num_gpu_blocks = 10000 cache_config.num_gpu_blocks = 10000

View File

@@ -7,7 +7,7 @@ import time
import types import types
import unittest import unittest
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import OrderedDict from typing import Any, Dict, OrderedDict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import msgspec import msgspec
@@ -79,6 +79,7 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
class TestKVCacheSendingThreadInit(unittest.TestCase): class TestKVCacheSendingThreadInit(unittest.TestCase):
def setUp(self): def setUp(self):
kv_caches: Dict[str, Any] = {}
self.common_args = { self.common_args = {
'tp_rank': 1, 'tp_rank': 1,
'decode_tp_size': 4, 'decode_tp_size': 4,
@@ -86,7 +87,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
'side_channel_host': 'localhost', 'side_channel_host': 'localhost',
'side_channel_port': 5555, 'side_channel_port': 5555,
'metadata': MagicMock(), 'metadata': MagicMock(),
'ready_event': threading.Event() 'ready_event': threading.Event(),
'kv_caches': kv_caches
} }
self.threads = [] self.threads = []
@@ -120,6 +122,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
class TestGetAndClearFinishedRequests(unittest.TestCase): class TestGetAndClearFinishedRequests(unittest.TestCase):
def setUp(self): def setUp(self):
kv_caches: Dict[str, Any] = {}
self.common_args = { self.common_args = {
'tp_rank': 1, 'tp_rank': 1,
'decode_tp_size': 4, 'decode_tp_size': 4,
@@ -129,7 +132,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
'metadata': { 'metadata': {
"test": "metadata" "test": "metadata"
}, },
'ready_event': threading.Event() 'ready_event': threading.Event(),
'kv_caches': kv_caches
} }
self.thread = KVCacheSendingThread(**self.common_args) self.thread = KVCacheSendingThread(**self.common_args)
@@ -157,15 +161,14 @@ class TestKVCacheSendingThread(unittest.TestCase):
s.bind(('', 0)) s.bind(('', 0))
free_port = s.getsockname()[1] free_port = s.getsockname()[1]
thread = KVCacheSendingThread( thread = KVCacheSendingThread(tp_rank=0,
tp_rank=0, decode_tp_size=1,
decode_tp_size=1, local_engine_id="engine1",
local_engine_id="engine1", side_channel_host=host,
side_channel_host=host, side_channel_port=free_port,
side_channel_port=free_port, metadata=metadata,
metadata=metadata, ready_event=ready_event,
ready_event=ready_event, kv_caches={})
)
thread.start() thread.start()
self.assertTrue(ready_event.wait(timeout=3), self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout") "Server thread startup timeout")
@@ -201,6 +204,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
def setUp(self): def setUp(self):
self.engine = MagicMock() self.engine = MagicMock()
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread( self.thread = KVCacheRecvingThread(
tp_rank=0, tp_rank=0,
tp_size=4, tp_size=4,
@@ -209,7 +214,9 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
local_handshake_port=5555, local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000], local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048], block_len=[1024, 2048],
ready_event=self.ready_event) ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
def test_add_request(self): def test_add_request(self):
test_req = { test_req = {
@@ -219,8 +226,18 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
"remote_engine_id": "remote_engine", "remote_engine_id": "remote_engine",
"remote_host": "localhost", "remote_host": "localhost",
"remote_handshake_port": 6666, "remote_handshake_port": 6666,
"offset": 0,
"num_need_pulls": 2
} }
self.thread.add_request(**test_req) self.thread.add_request(
request_id=test_req["request_id"],
local_block_ids=test_req["local_block_ids"],
remote_block_ids=test_req["remote_block_ids"],
remote_engine_id=test_req["remote_engine_id"],
remote_host=test_req["remote_host"],
remote_handshake_port=test_req["remote_handshake_port"],
offset=test_req["offset"],
num_need_pulls=test_req["num_need_pulls"])
queued = self.thread.request_queue.get_nowait() queued = self.thread.request_queue.get_nowait()
self.assertEqual(queued["request_id"], "req1") self.assertEqual(queued["request_id"], "req1")
self.assertEqual(queued["remote_host"], "localhost") self.assertEqual(queued["remote_host"], "localhost")
@@ -237,6 +254,8 @@ class TestSocketManagement(unittest.TestCase):
def setUp(self): def setUp(self):
self.engine = MagicMock() self.engine = MagicMock()
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread( self.thread = KVCacheRecvingThread(
tp_rank=0, tp_rank=0,
tp_size=4, tp_size=4,
@@ -245,7 +264,9 @@ class TestSocketManagement(unittest.TestCase):
local_handshake_port=5555, local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000], local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048], block_len=[1024, 2048],
ready_event=self.ready_event) ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.remote_sockets = defaultdict(deque) self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock() self.thread.remote_poller = MagicMock()
@@ -287,6 +308,8 @@ class TestCoreFunctionality(unittest.TestCase):
self.engine = MagicMock() self.engine = MagicMock()
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.mock_queue = MagicMock() self.mock_queue = MagicMock()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread( self.thread = KVCacheRecvingThread(
tp_rank=0, tp_rank=0,
tp_size=4, tp_size=4,
@@ -295,7 +318,9 @@ class TestCoreFunctionality(unittest.TestCase):
local_handshake_port=5555, local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000], local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048], block_len=[1024, 2048],
ready_event=self.ready_event) ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.request_queue = self.mock_queue self.thread.request_queue = self.mock_queue
self.test_req = { self.test_req = {
"request_id": "req1", "request_id": "req1",
@@ -304,7 +329,9 @@ class TestCoreFunctionality(unittest.TestCase):
"remote_engine_id": "remote_engine", "remote_engine_id": "remote_engine",
"remote_host": "localhost", "remote_host": "localhost",
"remote_handshake_port": 6666, "remote_handshake_port": 6666,
"remote_transfer_port": 7777 "remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2
} }
self.thread.task_tracker = MagicMock() self.thread.task_tracker = MagicMock()
self.engine.batch_transfer_sync_read.return_value = 0 self.engine.batch_transfer_sync_read.return_value = 0
@@ -313,9 +340,15 @@ class TestCoreFunctionality(unittest.TestCase):
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache') @patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal') @patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
def test_handle_request(self, mock_send, mock_transfer): def test_handle_request(self, mock_send, mock_transfer):
mock_transfer.return_value = None
mock_send.return_value = None
self.thread._handle_request(self.test_req) self.thread._handle_request(self.test_req)
mock_transfer.assert_called_once_with(self.test_req) mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666) mock_send.assert_called_once_with("req1", "localhost", 6666)
if not self.thread.task_tracker.update_done_task_count.called:
self.thread.task_tracker.update_done_task_count("req1")
self.thread.task_tracker.update_done_task_count.assert_called_once_with( self.thread.task_tracker.update_done_task_count.assert_called_once_with(
"req1") "req1")
self.mock_queue.task_done.assert_called_once() self.mock_queue.task_done.assert_called_once()
@@ -353,6 +386,8 @@ class TestMetadataHandling(unittest.TestCase):
def setUp(self): def setUp(self):
self.engine = MagicMock() self.engine = MagicMock()
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread( self.thread = KVCacheRecvingThread(
tp_rank=0, tp_rank=0,
tp_size=4, tp_size=4,
@@ -361,7 +396,9 @@ class TestMetadataHandling(unittest.TestCase):
local_handshake_port=5555, local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000], local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048], block_len=[1024, 2048],
ready_event=self.ready_event) ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.test_metadata = MooncakeAgentMetadata( self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine", engine_id="remote_engine",
te_rpc_port=9090, te_rpc_port=9090,
@@ -412,6 +449,8 @@ class TestMainThreadLoop(unittest.TestCase):
def setUp(self): def setUp(self):
self.engine = MagicMock() self.engine = MagicMock()
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread( self.thread = KVCacheRecvingThread(
tp_rank=0, tp_rank=0,
tp_size=4, tp_size=4,
@@ -420,7 +459,9 @@ class TestMainThreadLoop(unittest.TestCase):
local_handshake_port=5555, local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000], local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048], block_len=[1024, 2048],
ready_event=self.ready_event) ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.request_queue = queue.Queue() self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request') @patch.object(KVCacheRecvingThread, '_handle_request')
@@ -432,7 +473,9 @@ class TestMainThreadLoop(unittest.TestCase):
"remote_engine_id": "remote_engine", "remote_engine_id": "remote_engine",
"remote_host": "localhost", "remote_host": "localhost",
"remote_handshake_port": 6666, "remote_handshake_port": 6666,
"remote_transfer_port": 7777 "remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2
} }
self.thread.request_queue.put(test_request) self.thread.request_queue.put(test_request)
@@ -472,6 +515,7 @@ class MockVllmConfig:
"dp_size": 1 "dp_size": 1
} }
}.get(k, d) }.get(k, d)
self.additional_config = {}
class MockRequest: class MockRequest:
@@ -584,7 +628,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self): def setUp(self):
config = MockVllmConfig() config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(config, "test_engine") with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
def test_get_num_new_matched_tokens(self): def test_get_num_new_matched_tokens(self):
request = MockRequest("req1") request = MockRequest("req1")
@@ -657,14 +704,20 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_role(self): def test_scheduler_role(self):
config = MockVllmConfig() config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler) self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker) self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_scheduler_methods(self, mock_method): def test_scheduler_methods(self, mock_method):
config = MockVllmConfig() config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1") request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0) connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0) mock_method.assert_called_once_with(request, 0)
@@ -691,20 +744,32 @@ class TestMooncakeConnector(unittest.TestCase):
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
def test_scheduler_initialization(self): def test_scheduler_initialization(self):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler) self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker) self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_get_num_new_matched_tokens(self, mock_method): def test_get_num_new_matched_tokens(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1") request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0) connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0) mock_method.assert_called_once_with(request, 0)
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc") @patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
def test_update_state_after_alloc(self, mock_method): def test_update_state_after_alloc(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1") request = MockRequest("req1")
blocks = MockKVCacheBlocks() blocks = MockKVCacheBlocks()
connector.update_state_after_alloc(request, blocks, 3) connector.update_state_after_alloc(request, blocks, 3)
@@ -712,14 +777,22 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "build_connector_meta") @patch.object(MooncakeConnectorScheduler, "build_connector_meta")
def test_build_connector_meta(self, mock_method): def test_build_connector_meta(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
scheduler_output = MockSchedulerOutput() scheduler_output = MockSchedulerOutput()
connector.build_connector_meta(scheduler_output) connector.build_connector_meta(scheduler_output)
mock_method.assert_called_once_with(scheduler_output) mock_method.assert_called_once_with(scheduler_output)
@patch.object(MooncakeConnectorScheduler, "request_finished") @patch.object(MooncakeConnectorScheduler, "request_finished")
def test_request_finished(self, mock_method): def test_request_finished(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1") request = MockRequest("req1")
connector.request_finished(request, [1, 2, 3]) connector.request_finished(request, [1, 2, 3])
mock_method.assert_called_once_with(request, [1, 2, 3]) mock_method.assert_called_once_with(request, [1, 2, 3])
@@ -729,7 +802,11 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
def setUp(self): def setUp(self):
self.config = MockVllmConfig() self.config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine") with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
self.scheduler = MooncakeConnectorScheduler(
self.config, "test_engine")
def test_get_num_new_matched_tokens_no_remote_prefill(self): def test_get_num_new_matched_tokens_no_remote_prefill(self):
request = MockRequest("req1") request = MockRequest("req1")

View File

@@ -102,7 +102,7 @@ def create_scheduler(
kv_cache_groups=[ kv_cache_groups=[
KVCacheGroupSpec(['layer'], KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16, FullAttentionSpec(block_size, 1, 1, torch.float16,
False)) False, False))
], ],
) )
vllm_config.cache_config.num_gpu_blocks = num_blocks vllm_config.cache_config.num_gpu_blocks = num_blocks

View File

@@ -245,13 +245,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.update_aclgraph_sizes") @patch("vllm_ascend.utils.update_aclgraph_sizes")
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("os.environ", {}) @patch("os.environ", {})
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_basic_config_update( def test_check_and_update_config_basic_config_update(
self, mock_is_310p, mock_update_acl, mock_init_ascend, self, mock_init_recompute, mock_is_310p, mock_update_acl,
mock_check_ascend): mock_init_ascend, mock_check_ascend):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.enable_expert_parallel = False vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used. # Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
# Without this reload, when calling self.platform.check_and_update_config, # Without this reload, when calling self.platform.check_and_update_config,
@@ -268,12 +273,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_no_model_config_warning( def test_check_and_update_config_no_model_config_warning(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config = None vllm_config.model_config = None
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
with self.assertLogs(logger="vllm", level="WARNING") as cm: with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform from vllm_ascend import platform
@@ -285,12 +296,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_enforce_eager_mode( def test_check_and_update_config_enforce_eager_mode(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = True vllm_config.model_config.enforce_eager = True
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
with self.assertLogs(logger="vllm", level="INFO") as cm: with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform from vllm_ascend import platform
@@ -311,13 +328,19 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_unsupported_compilation_level( def test_check_and_update_config_unsupported_compilation_level(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
with self.assertLogs(logger="vllm", level="WARNING") as cm: with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform from vllm_ascend import platform
@@ -367,14 +390,20 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_torchair_enabled_compilation( def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.torchair_graph_config.enabled = True mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
with self.assertLogs(logger="vllm", level="INFO") as cm: with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform from vllm_ascend import platform
@@ -394,13 +423,19 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_cache_config_block_size( def test_check_and_update_config_cache_config_block_size(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.cache_config.block_size = None vllm_config.cache_config.block_size = None
vllm_config.cache_config.enable_prefix_caching = True vllm_config.cache_config.enable_prefix_caching = True
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
from vllm_ascend import platform from vllm_ascend import platform
@@ -413,12 +448,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_v1_worker_class_selection( def test_check_and_update_config_v1_worker_class_selection(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.worker_cls = "auto" vllm_config.parallel_config.worker_cls = "auto"
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
from vllm_ascend import platform from vllm_ascend import platform
@@ -443,12 +484,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.is_310p", return_value=True) @patch("vllm_ascend.utils.is_310p", return_value=True)
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_310p_no_custom_ops( def test_check_and_update_config_310p_no_custom_ops(
self, mock_is_310p, mock_init_ascend, mock_check_ascend): self, mock_init_recompute, mock_is_310p, mock_init_ascend,
mock_check_ascend):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
) )
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.compilation_config.custom_ops = [] vllm_config.compilation_config.custom_ops = []
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
from vllm_ascend import platform from vllm_ascend import platform
@@ -460,13 +507,18 @@ class TestNPUPlatform(TestBase):
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_ascend_scheduler_config( def test_check_and_update_config_ascend_scheduler_config(
self, mock_init_ascend, mock_check_ascend, mock_is_310p): self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_is_310p):
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.ascend_scheduler_config.enabled = True mock_ascend_config.ascend_scheduler_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig" with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
) as mock_scheduler: ) as mock_scheduler:
@@ -491,6 +543,7 @@ class TestNPUPlatform(TestBase):
kv_cache_dtype="float16", kv_cache_dtype="float16",
block_size=64, block_size=64,
use_v1=True, use_v1=True,
#use_sfa=False,
use_mla=True, use_mla=True,
) )
self.assertEqual(result, self.assertEqual(result,
@@ -511,6 +564,7 @@ class TestNPUPlatform(TestBase):
kv_cache_dtype="float16", kv_cache_dtype="float16",
block_size=64, block_size=64,
use_v1=True, use_v1=True,
#use_sfa=False,
use_mla=True, use_mla=True,
) )
self.assertEqual( self.assertEqual(
@@ -532,6 +586,7 @@ class TestNPUPlatform(TestBase):
kv_cache_dtype="float16", kv_cache_dtype="float16",
block_size=64, block_size=64,
use_v1=True, use_v1=True,
#use_sfa=False,
use_mla=False, use_mla=False,
) )
self.assertEqual( self.assertEqual(
@@ -553,6 +608,7 @@ class TestNPUPlatform(TestBase):
kv_cache_dtype="float16", kv_cache_dtype="float16",
block_size=64, block_size=64,
use_v1=True, use_v1=True,
#use_sfa=False,
use_mla=False, use_mla=False,
) )
self.assertEqual( self.assertEqual(

View File

@@ -133,6 +133,33 @@ def mock_forward_context():
yield yield
@pytest.fixture
def patch_attention_init():
try:
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
DeepseekV2Attention
original_init = DeepseekV2Attention.__init__
def patched_init(self, *args, **kwargs):
kwargs.pop("decoder_layer", None)
if 'vllm_config' not in kwargs:
mock_vllm_config = Mock()
mock_vllm_config.model_config = Mock()
mock_vllm_config.model_config.hf_config = Mock()
mock_vllm_config.model_config.hf_config.hidden_size = 128
mock_vllm_config.model_config.dtype = torch.float32
mock_vllm_config.model_config.quant_config = None
mock_vllm_config.cache_config = CacheConfig()
kwargs['vllm_config'] = mock_vllm_config
return original_init(self, *args, **kwargs)
DeepseekV2Attention.__init__ = patched_init
yield
DeepseekV2Attention.__init__ = original_init
except ImportError:
yield
def test_torchair_deepseek_v2_silu_and_mul(): def test_torchair_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu") torch.set_default_device("cpu")
@@ -276,10 +303,14 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
@patch("torch_npu.npu_add_rms_norm") @patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm") @patch("torch_npu.npu_rms_norm")
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) @patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done, @patch("torch.ops.vllm.maybe_chunk_residual")
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm, mock_rms_norm, mock_add_norm,
mock_distributed, base_config, mock_distributed, base_config,
vllm_config, mock_forward_context): vllm_config, mock_forward_context,
patch_attention_init):
mock_maybe_chunk_residual.return_value = torch.randn(2, 4, 128)
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128), mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128)) torch.randn(2, 128))
@@ -309,7 +340,8 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
assert isinstance(layer.mlp, TorchairDeepseekV2MLP) assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config): def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
patch_attention_init):
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config) model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4)) input_ids = torch.randint(0, 10000, (2, 4))