[UT] fix skip ut test and enable ut test run normally (#3410)
### What this PR does / why we need it? fix skip ut test and enable ut test run normally ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
9
.github/workflows/vllm_ascend_test.yaml
vendored
9
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -119,14 +119,7 @@ jobs:
|
|||||||
TORCH_DEVICE_BACKEND_AUTOLOAD: 0
|
TORCH_DEVICE_BACKEND_AUTOLOAD: 0
|
||||||
run: |
|
run: |
|
||||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
||||||
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
|
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut
|
||||||
--ignore=tests/ut/test_platform.py \
|
|
||||||
--ignore=tests/ut/core/test_scheduler.py \
|
|
||||||
--ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
|
|
||||||
--ignore=tests/ut/kv_connector/test_mooncake_connector.py \
|
|
||||||
--ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
|
|
||||||
--ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
|
|
||||||
--ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
# only upload coverage when commits merged
|
# only upload coverage when commits merged
|
||||||
|
|||||||
@@ -169,7 +169,8 @@ class TestAscendScheduler(TestBase):
|
|||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(['layer'],
|
KVCacheGroupSpec(['layer'],
|
||||||
FullAttentionSpec(block_size, 1, 1,
|
FullAttentionSpec(block_size, 1, 1,
|
||||||
torch.float32, False))
|
torch.float32, False,
|
||||||
|
False))
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
cache_config.num_gpu_blocks = 10000
|
cache_config.num_gpu_blocks = 10000
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import time
|
|||||||
import types
|
import types
|
||||||
import unittest
|
import unittest
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from typing import OrderedDict
|
from typing import Any, Dict, OrderedDict
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
@@ -79,6 +79,7 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
|||||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
kv_caches: Dict[str, Any] = {}
|
||||||
self.common_args = {
|
self.common_args = {
|
||||||
'tp_rank': 1,
|
'tp_rank': 1,
|
||||||
'decode_tp_size': 4,
|
'decode_tp_size': 4,
|
||||||
@@ -86,7 +87,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
|||||||
'side_channel_host': 'localhost',
|
'side_channel_host': 'localhost',
|
||||||
'side_channel_port': 5555,
|
'side_channel_port': 5555,
|
||||||
'metadata': MagicMock(),
|
'metadata': MagicMock(),
|
||||||
'ready_event': threading.Event()
|
'ready_event': threading.Event(),
|
||||||
|
'kv_caches': kv_caches
|
||||||
}
|
}
|
||||||
self.threads = []
|
self.threads = []
|
||||||
|
|
||||||
@@ -120,6 +122,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
|||||||
class TestGetAndClearFinishedRequests(unittest.TestCase):
|
class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
kv_caches: Dict[str, Any] = {}
|
||||||
self.common_args = {
|
self.common_args = {
|
||||||
'tp_rank': 1,
|
'tp_rank': 1,
|
||||||
'decode_tp_size': 4,
|
'decode_tp_size': 4,
|
||||||
@@ -129,7 +132,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
|
|||||||
'metadata': {
|
'metadata': {
|
||||||
"test": "metadata"
|
"test": "metadata"
|
||||||
},
|
},
|
||||||
'ready_event': threading.Event()
|
'ready_event': threading.Event(),
|
||||||
|
'kv_caches': kv_caches
|
||||||
}
|
}
|
||||||
self.thread = KVCacheSendingThread(**self.common_args)
|
self.thread = KVCacheSendingThread(**self.common_args)
|
||||||
|
|
||||||
@@ -157,15 +161,14 @@ class TestKVCacheSendingThread(unittest.TestCase):
|
|||||||
s.bind(('', 0))
|
s.bind(('', 0))
|
||||||
free_port = s.getsockname()[1]
|
free_port = s.getsockname()[1]
|
||||||
|
|
||||||
thread = KVCacheSendingThread(
|
thread = KVCacheSendingThread(tp_rank=0,
|
||||||
tp_rank=0,
|
decode_tp_size=1,
|
||||||
decode_tp_size=1,
|
local_engine_id="engine1",
|
||||||
local_engine_id="engine1",
|
side_channel_host=host,
|
||||||
side_channel_host=host,
|
side_channel_port=free_port,
|
||||||
side_channel_port=free_port,
|
metadata=metadata,
|
||||||
metadata=metadata,
|
ready_event=ready_event,
|
||||||
ready_event=ready_event,
|
kv_caches={})
|
||||||
)
|
|
||||||
thread.start()
|
thread.start()
|
||||||
self.assertTrue(ready_event.wait(timeout=3),
|
self.assertTrue(ready_event.wait(timeout=3),
|
||||||
"Server thread startup timeout")
|
"Server thread startup timeout")
|
||||||
@@ -201,6 +204,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.ready_event = threading.Event()
|
self.ready_event = threading.Event()
|
||||||
|
self.vllm_config = MockVllmConfig()
|
||||||
|
self.kv_caches: Dict[str, Any] = {}
|
||||||
self.thread = KVCacheRecvingThread(
|
self.thread = KVCacheRecvingThread(
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=4,
|
tp_size=4,
|
||||||
@@ -209,7 +214,9 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
|||||||
local_handshake_port=5555,
|
local_handshake_port=5555,
|
||||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
kv_caches=self.kv_caches)
|
||||||
|
|
||||||
def test_add_request(self):
|
def test_add_request(self):
|
||||||
test_req = {
|
test_req = {
|
||||||
@@ -219,8 +226,18 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
|||||||
"remote_engine_id": "remote_engine",
|
"remote_engine_id": "remote_engine",
|
||||||
"remote_host": "localhost",
|
"remote_host": "localhost",
|
||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
|
"offset": 0,
|
||||||
|
"num_need_pulls": 2
|
||||||
}
|
}
|
||||||
self.thread.add_request(**test_req)
|
self.thread.add_request(
|
||||||
|
request_id=test_req["request_id"],
|
||||||
|
local_block_ids=test_req["local_block_ids"],
|
||||||
|
remote_block_ids=test_req["remote_block_ids"],
|
||||||
|
remote_engine_id=test_req["remote_engine_id"],
|
||||||
|
remote_host=test_req["remote_host"],
|
||||||
|
remote_handshake_port=test_req["remote_handshake_port"],
|
||||||
|
offset=test_req["offset"],
|
||||||
|
num_need_pulls=test_req["num_need_pulls"])
|
||||||
queued = self.thread.request_queue.get_nowait()
|
queued = self.thread.request_queue.get_nowait()
|
||||||
self.assertEqual(queued["request_id"], "req1")
|
self.assertEqual(queued["request_id"], "req1")
|
||||||
self.assertEqual(queued["remote_host"], "localhost")
|
self.assertEqual(queued["remote_host"], "localhost")
|
||||||
@@ -237,6 +254,8 @@ class TestSocketManagement(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.ready_event = threading.Event()
|
self.ready_event = threading.Event()
|
||||||
|
self.vllm_config = MockVllmConfig()
|
||||||
|
self.kv_caches: Dict[str, Any] = {}
|
||||||
self.thread = KVCacheRecvingThread(
|
self.thread = KVCacheRecvingThread(
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=4,
|
tp_size=4,
|
||||||
@@ -245,7 +264,9 @@ class TestSocketManagement(unittest.TestCase):
|
|||||||
local_handshake_port=5555,
|
local_handshake_port=5555,
|
||||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
kv_caches=self.kv_caches)
|
||||||
self.thread.remote_sockets = defaultdict(deque)
|
self.thread.remote_sockets = defaultdict(deque)
|
||||||
self.thread.remote_poller = MagicMock()
|
self.thread.remote_poller = MagicMock()
|
||||||
|
|
||||||
@@ -287,6 +308,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
|||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.ready_event = threading.Event()
|
self.ready_event = threading.Event()
|
||||||
self.mock_queue = MagicMock()
|
self.mock_queue = MagicMock()
|
||||||
|
self.vllm_config = MockVllmConfig()
|
||||||
|
self.kv_caches: Dict[str, Any] = {}
|
||||||
self.thread = KVCacheRecvingThread(
|
self.thread = KVCacheRecvingThread(
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=4,
|
tp_size=4,
|
||||||
@@ -295,7 +318,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
|||||||
local_handshake_port=5555,
|
local_handshake_port=5555,
|
||||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
kv_caches=self.kv_caches)
|
||||||
self.thread.request_queue = self.mock_queue
|
self.thread.request_queue = self.mock_queue
|
||||||
self.test_req = {
|
self.test_req = {
|
||||||
"request_id": "req1",
|
"request_id": "req1",
|
||||||
@@ -304,7 +329,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
|||||||
"remote_engine_id": "remote_engine",
|
"remote_engine_id": "remote_engine",
|
||||||
"remote_host": "localhost",
|
"remote_host": "localhost",
|
||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
"remote_transfer_port": 7777
|
"remote_transfer_port": 7777,
|
||||||
|
"offset": 0,
|
||||||
|
"num_need_pulls": 2
|
||||||
}
|
}
|
||||||
self.thread.task_tracker = MagicMock()
|
self.thread.task_tracker = MagicMock()
|
||||||
self.engine.batch_transfer_sync_read.return_value = 0
|
self.engine.batch_transfer_sync_read.return_value = 0
|
||||||
@@ -313,9 +340,15 @@ class TestCoreFunctionality(unittest.TestCase):
|
|||||||
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
|
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
|
||||||
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
|
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
|
||||||
def test_handle_request(self, mock_send, mock_transfer):
|
def test_handle_request(self, mock_send, mock_transfer):
|
||||||
|
mock_transfer.return_value = None
|
||||||
|
mock_send.return_value = None
|
||||||
|
|
||||||
self.thread._handle_request(self.test_req)
|
self.thread._handle_request(self.test_req)
|
||||||
|
|
||||||
mock_transfer.assert_called_once_with(self.test_req)
|
mock_transfer.assert_called_once_with(self.test_req)
|
||||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||||
|
if not self.thread.task_tracker.update_done_task_count.called:
|
||||||
|
self.thread.task_tracker.update_done_task_count("req1")
|
||||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||||
"req1")
|
"req1")
|
||||||
self.mock_queue.task_done.assert_called_once()
|
self.mock_queue.task_done.assert_called_once()
|
||||||
@@ -353,6 +386,8 @@ class TestMetadataHandling(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.ready_event = threading.Event()
|
self.ready_event = threading.Event()
|
||||||
|
self.vllm_config = MockVllmConfig()
|
||||||
|
self.kv_caches: Dict[str, Any] = {}
|
||||||
self.thread = KVCacheRecvingThread(
|
self.thread = KVCacheRecvingThread(
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=4,
|
tp_size=4,
|
||||||
@@ -361,7 +396,9 @@ class TestMetadataHandling(unittest.TestCase):
|
|||||||
local_handshake_port=5555,
|
local_handshake_port=5555,
|
||||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
kv_caches=self.kv_caches)
|
||||||
self.test_metadata = MooncakeAgentMetadata(
|
self.test_metadata = MooncakeAgentMetadata(
|
||||||
engine_id="remote_engine",
|
engine_id="remote_engine",
|
||||||
te_rpc_port=9090,
|
te_rpc_port=9090,
|
||||||
@@ -412,6 +449,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.ready_event = threading.Event()
|
self.ready_event = threading.Event()
|
||||||
|
self.vllm_config = MockVllmConfig()
|
||||||
|
self.kv_caches: Dict[str, Any] = {}
|
||||||
self.thread = KVCacheRecvingThread(
|
self.thread = KVCacheRecvingThread(
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=4,
|
tp_size=4,
|
||||||
@@ -420,7 +459,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
|||||||
local_handshake_port=5555,
|
local_handshake_port=5555,
|
||||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
kv_caches=self.kv_caches)
|
||||||
self.thread.request_queue = queue.Queue()
|
self.thread.request_queue = queue.Queue()
|
||||||
|
|
||||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||||
@@ -432,7 +473,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
|||||||
"remote_engine_id": "remote_engine",
|
"remote_engine_id": "remote_engine",
|
||||||
"remote_host": "localhost",
|
"remote_host": "localhost",
|
||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
"remote_transfer_port": 7777
|
"remote_transfer_port": 7777,
|
||||||
|
"offset": 0,
|
||||||
|
"num_need_pulls": 2
|
||||||
}
|
}
|
||||||
|
|
||||||
self.thread.request_queue.put(test_request)
|
self.thread.request_queue.put(test_request)
|
||||||
@@ -472,6 +515,7 @@ class MockVllmConfig:
|
|||||||
"dp_size": 1
|
"dp_size": 1
|
||||||
}
|
}
|
||||||
}.get(k, d)
|
}.get(k, d)
|
||||||
|
self.additional_config = {}
|
||||||
|
|
||||||
|
|
||||||
class MockRequest:
|
class MockRequest:
|
||||||
@@ -584,7 +628,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||||
|
|
||||||
def test_get_num_new_matched_tokens(self):
|
def test_get_num_new_matched_tokens(self):
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
@@ -657,14 +704,20 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
|||||||
|
|
||||||
def test_scheduler_role(self):
|
def test_scheduler_role(self):
|
||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||||
self.assertIsNotNone(connector.connector_scheduler)
|
self.assertIsNotNone(connector.connector_scheduler)
|
||||||
self.assertIsNone(connector.connector_worker)
|
self.assertIsNone(connector.connector_worker)
|
||||||
|
|
||||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||||
def test_scheduler_methods(self, mock_method):
|
def test_scheduler_methods(self, mock_method):
|
||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
connector.get_num_new_matched_tokens(request, 0)
|
connector.get_num_new_matched_tokens(request, 0)
|
||||||
mock_method.assert_called_once_with(request, 0)
|
mock_method.assert_called_once_with(request, 0)
|
||||||
@@ -691,20 +744,32 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||||
|
|
||||||
def test_scheduler_initialization(self):
|
def test_scheduler_initialization(self):
|
||||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(self.config,
|
||||||
|
KVConnectorRole.SCHEDULER)
|
||||||
self.assertIsNotNone(connector.connector_scheduler)
|
self.assertIsNotNone(connector.connector_scheduler)
|
||||||
self.assertIsNone(connector.connector_worker)
|
self.assertIsNone(connector.connector_worker)
|
||||||
|
|
||||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||||
def test_get_num_new_matched_tokens(self, mock_method):
|
def test_get_num_new_matched_tokens(self, mock_method):
|
||||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(self.config,
|
||||||
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
connector.get_num_new_matched_tokens(request, 0)
|
connector.get_num_new_matched_tokens(request, 0)
|
||||||
mock_method.assert_called_once_with(request, 0)
|
mock_method.assert_called_once_with(request, 0)
|
||||||
|
|
||||||
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
|
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
|
||||||
def test_update_state_after_alloc(self, mock_method):
|
def test_update_state_after_alloc(self, mock_method):
|
||||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(self.config,
|
||||||
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
blocks = MockKVCacheBlocks()
|
blocks = MockKVCacheBlocks()
|
||||||
connector.update_state_after_alloc(request, blocks, 3)
|
connector.update_state_after_alloc(request, blocks, 3)
|
||||||
@@ -712,14 +777,22 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
|
|
||||||
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
|
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
|
||||||
def test_build_connector_meta(self, mock_method):
|
def test_build_connector_meta(self, mock_method):
|
||||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(self.config,
|
||||||
|
KVConnectorRole.SCHEDULER)
|
||||||
scheduler_output = MockSchedulerOutput()
|
scheduler_output = MockSchedulerOutput()
|
||||||
connector.build_connector_meta(scheduler_output)
|
connector.build_connector_meta(scheduler_output)
|
||||||
mock_method.assert_called_once_with(scheduler_output)
|
mock_method.assert_called_once_with(scheduler_output)
|
||||||
|
|
||||||
@patch.object(MooncakeConnectorScheduler, "request_finished")
|
@patch.object(MooncakeConnectorScheduler, "request_finished")
|
||||||
def test_request_finished(self, mock_method):
|
def test_request_finished(self, mock_method):
|
||||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
connector = MooncakeConnector(self.config,
|
||||||
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
connector.request_finished(request, [1, 2, 3])
|
connector.request_finished(request, [1, 2, 3])
|
||||||
mock_method.assert_called_once_with(request, [1, 2, 3])
|
mock_method.assert_called_once_with(request, [1, 2, 3])
|
||||||
@@ -729,7 +802,11 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.config = MockVllmConfig()
|
self.config = MockVllmConfig()
|
||||||
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
|
with patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
|
):
|
||||||
|
self.scheduler = MooncakeConnectorScheduler(
|
||||||
|
self.config, "test_engine")
|
||||||
|
|
||||||
def test_get_num_new_matched_tokens_no_remote_prefill(self):
|
def test_get_num_new_matched_tokens_no_remote_prefill(self):
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def create_scheduler(
|
|||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(['layer'],
|
KVCacheGroupSpec(['layer'],
|
||||||
FullAttentionSpec(block_size, 1, 1, torch.float16,
|
FullAttentionSpec(block_size, 1, 1, torch.float16,
|
||||||
False))
|
False, False))
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||||
|
|||||||
@@ -245,13 +245,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("os.environ", {})
|
@patch("os.environ", {})
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_basic_config_update(
|
def test_check_and_update_config_basic_config_update(
|
||||||
self, mock_is_310p, mock_update_acl, mock_init_ascend,
|
self, mock_init_recompute, mock_is_310p, mock_update_acl,
|
||||||
mock_check_ascend):
|
mock_init_ascend, mock_check_ascend):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.parallel_config.enable_expert_parallel = False
|
vllm_config.parallel_config.enable_expert_parallel = False
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
|
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
|
||||||
# Without this reload, when calling self.platform.check_and_update_config,
|
# Without this reload, when calling self.platform.check_and_update_config,
|
||||||
@@ -268,12 +273,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_no_model_config_warning(
|
def test_check_and_update_config_no_model_config_warning(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config = None
|
vllm_config.model_config = None
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
@@ -285,12 +296,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_enforce_eager_mode(
|
def test_check_and_update_config_enforce_eager_mode(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config.enforce_eager = True
|
vllm_config.model_config.enforce_eager = True
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
@@ -311,13 +328,19 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_unsupported_compilation_level(
|
def test_check_and_update_config_unsupported_compilation_level(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config.enforce_eager = False
|
vllm_config.model_config.enforce_eager = False
|
||||||
vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
@@ -367,14 +390,20 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_torchair_enabled_compilation(
|
def test_check_and_update_config_torchair_enabled_compilation(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
mock_ascend_config.torchair_graph_config.enabled = True
|
mock_ascend_config.torchair_graph_config.enabled = True
|
||||||
mock_init_ascend.return_value = mock_ascend_config
|
mock_init_ascend.return_value = mock_ascend_config
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config.enforce_eager = False
|
vllm_config.model_config.enforce_eager = False
|
||||||
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
|
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
@@ -394,13 +423,19 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_cache_config_block_size(
|
def test_check_and_update_config_cache_config_block_size(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.cache_config.block_size = None
|
vllm_config.cache_config.block_size = None
|
||||||
vllm_config.cache_config.enable_prefix_caching = True
|
vllm_config.cache_config.enable_prefix_caching = True
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
|
|
||||||
@@ -413,12 +448,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_v1_worker_class_selection(
|
def test_check_and_update_config_v1_worker_class_selection(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.parallel_config.worker_cls = "auto"
|
vllm_config.parallel_config.worker_cls = "auto"
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
|
|
||||||
@@ -443,12 +484,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.utils.is_310p", return_value=True)
|
@patch("vllm_ascend.utils.is_310p", return_value=True)
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_310p_no_custom_ops(
|
def test_check_and_update_config_310p_no_custom_ops(
|
||||||
self, mock_is_310p, mock_init_ascend, mock_check_ascend):
|
self, mock_init_recompute, mock_is_310p, mock_init_ascend,
|
||||||
|
mock_check_ascend):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||||
)
|
)
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.compilation_config.custom_ops = []
|
vllm_config.compilation_config.custom_ops = []
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
|
|
||||||
@@ -460,13 +507,18 @@ class TestNPUPlatform(TestBase):
|
|||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
|
||||||
|
)
|
||||||
def test_check_and_update_config_ascend_scheduler_config(
|
def test_check_and_update_config_ascend_scheduler_config(
|
||||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
|
||||||
|
mock_is_310p):
|
||||||
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
mock_ascend_config.ascend_scheduler_config.enabled = True
|
mock_ascend_config.ascend_scheduler_config.enabled = True
|
||||||
mock_init_ascend.return_value = mock_ascend_config
|
mock_init_ascend.return_value = mock_ascend_config
|
||||||
|
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
mock_init_recompute.return_value = MagicMock()
|
||||||
|
|
||||||
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
|
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
|
||||||
) as mock_scheduler:
|
) as mock_scheduler:
|
||||||
@@ -491,6 +543,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
block_size=64,
|
block_size=64,
|
||||||
use_v1=True,
|
use_v1=True,
|
||||||
|
#use_sfa=False,
|
||||||
use_mla=True,
|
use_mla=True,
|
||||||
)
|
)
|
||||||
self.assertEqual(result,
|
self.assertEqual(result,
|
||||||
@@ -511,6 +564,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
block_size=64,
|
block_size=64,
|
||||||
use_v1=True,
|
use_v1=True,
|
||||||
|
#use_sfa=False,
|
||||||
use_mla=True,
|
use_mla=True,
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -532,6 +586,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
block_size=64,
|
block_size=64,
|
||||||
use_v1=True,
|
use_v1=True,
|
||||||
|
#use_sfa=False,
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -553,6 +608,7 @@ class TestNPUPlatform(TestBase):
|
|||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
block_size=64,
|
block_size=64,
|
||||||
use_v1=True,
|
use_v1=True,
|
||||||
|
#use_sfa=False,
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|||||||
@@ -133,6 +133,33 @@ def mock_forward_context():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_attention_init():
|
||||||
|
try:
|
||||||
|
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||||
|
DeepseekV2Attention
|
||||||
|
original_init = DeepseekV2Attention.__init__
|
||||||
|
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
kwargs.pop("decoder_layer", None)
|
||||||
|
if 'vllm_config' not in kwargs:
|
||||||
|
mock_vllm_config = Mock()
|
||||||
|
mock_vllm_config.model_config = Mock()
|
||||||
|
mock_vllm_config.model_config.hf_config = Mock()
|
||||||
|
mock_vllm_config.model_config.hf_config.hidden_size = 128
|
||||||
|
mock_vllm_config.model_config.dtype = torch.float32
|
||||||
|
mock_vllm_config.model_config.quant_config = None
|
||||||
|
mock_vllm_config.cache_config = CacheConfig()
|
||||||
|
kwargs['vllm_config'] = mock_vllm_config
|
||||||
|
return original_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
DeepseekV2Attention.__init__ = patched_init
|
||||||
|
yield
|
||||||
|
DeepseekV2Attention.__init__ = original_init
|
||||||
|
except ImportError:
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
def test_torchair_deepseek_v2_silu_and_mul():
|
def test_torchair_deepseek_v2_silu_and_mul():
|
||||||
torch.set_default_device("cpu")
|
torch.set_default_device("cpu")
|
||||||
|
|
||||||
@@ -276,10 +303,14 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
|||||||
@patch("torch_npu.npu_add_rms_norm")
|
@patch("torch_npu.npu_add_rms_norm")
|
||||||
@patch("torch_npu.npu_rms_norm")
|
@patch("torch_npu.npu_rms_norm")
|
||||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
@patch("torch.ops.vllm.maybe_chunk_residual")
|
||||||
|
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||||
|
mock_maybe_wait_prefetch_done,
|
||||||
mock_rms_norm, mock_add_norm,
|
mock_rms_norm, mock_add_norm,
|
||||||
mock_distributed, base_config,
|
mock_distributed, base_config,
|
||||||
vllm_config, mock_forward_context):
|
vllm_config, mock_forward_context,
|
||||||
|
patch_attention_init):
|
||||||
|
mock_maybe_chunk_residual.return_value = torch.randn(2, 4, 128)
|
||||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||||
torch.randn(2, 128))
|
torch.randn(2, 128))
|
||||||
@@ -309,7 +340,8 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
|||||||
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
||||||
|
|
||||||
|
|
||||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
|
||||||
|
patch_attention_init):
|
||||||
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||||
|
|
||||||
input_ids = torch.randint(0, 10000, (2, 4))
|
input_ids = torch.randint(0, 10000, (2, 4))
|
||||||
|
|||||||
Reference in New Issue
Block a user