[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)
### What this PR does / why we need it?
Refactor the PCP & DCP patches in the unit tests: merge and reuse the
communication-group and communication-function patches to reduce code duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.utils.network_utils import make_zmq_path
|
||||
|
||||
fake_engine = types.ModuleType("mooncake.engine")
|
||||
@@ -23,6 +22,7 @@ _mock_ascend_config = MagicMock(enable_kv_nz=False)
|
||||
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
|
||||
_mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||
_mock_dcp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
|
||||
return_value=_mock_pp_group).start()
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
|
||||
@@ -35,6 +35,7 @@ patch(
|
||||
return_value=0).start()
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
|
||||
return_value=_mock_pcp_group).start()
|
||||
patch('vllm.distributed.parallel_state._DCP', _mock_dcp_group).start()
|
||||
|
||||
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
||||
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
||||
@@ -1098,17 +1099,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
self.mock_transfer_engine.get_rpc_port.return_value = 9090
|
||||
self.mock_transfer_engine.initialize.return_value = 0
|
||||
self.mock_transfer_engine.register_memory.return_value = 0
|
||||
self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
self.mock_dcp_group.rank_in_group = 0
|
||||
self.mock_dcp_group.world_size = 1
|
||||
self.mock_dcp_group.device_group = MagicMock()
|
||||
self.mock_dcp = MagicMock()
|
||||
self.mock_dcp.world_size = 1
|
||||
|
||||
self.mock_pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
self.mock_pcp_group.rank_in_group = 0
|
||||
self.mock_pcp_group.world_size = 1
|
||||
self.mock_pcp_group.device_group = MagicMock()
|
||||
|
||||
self.patches = [
|
||||
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
||||
@@ -1143,13 +1133,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
MagicMock()),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
|
||||
MagicMock()),
|
||||
patch('vllm.distributed.parallel_state.get_dcp_group',
|
||||
return_value=self.mock_dcp_group),
|
||||
patch('vllm.distributed.parallel_state._DCP',
|
||||
return_value=self.mock_dcp),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
|
||||
return_value=1),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||
return_value=MagicMock()),
|
||||
|
||||
Reference in New Issue
Block a user