diff --git a/docs/source/tutorials/DeepSeek-V3.1.md b/docs/source/tutorials/DeepSeek-V3.1.md
index 5db808a1..c1573189 100644
--- a/docs/source/tutorials/DeepSeek-V3.1.md
+++ b/docs/source/tutorials/DeepSeek-V3.1.md
@@ -430,7 +430,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
         "engine_id": "0",
         "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
         "kv_connector_extra_config": {
-            "use_ascend_direct": true,
             "prefill": {
                 "dp_size": 2,
                 "tp_size": 8
@@ -510,7 +509,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
         "engine_id": "1",
         "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
         "kv_connector_extra_config": {
-            "use_ascend_direct": true,
             "prefill": {
                 "dp_size": 2,
                 "tp_size": 8
@@ -590,7 +588,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
         "engine_id": "2",
         "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
         "kv_connector_extra_config": {
-            "use_ascend_direct": true,
             "prefill": {
                 "dp_size": 2,
                 "tp_size": 8
@@ -670,7 +667,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
         "engine_id": "3",
         "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
         "kv_connector_extra_config": {
-            "use_ascend_direct": true,
             "prefill": {
                 "dp_size": 2,
                 "tp_size": 8
diff --git a/docs/source/user_guide/feature_guide/kv_pool.md b/docs/source/user_guide/feature_guide/kv_pool.md
index 208fbada..48b8e32b 100644
--- a/docs/source/user_guide/feature_guide/kv_pool.md
+++ b/docs/source/user_guide/feature_guide/kv_pool.md
@@ -41,7 +41,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
     "metadata_server": "P2PHANDSHAKE",
     "protocol": "ascend",
     "device_name": "",
-    "use_ascend_direct": true,
     "alloc_in_same_node": true,
     "master_server_address": "xx.xx.xx.xx:50088",
     "global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
@@ -52,7 +51,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
 **metadata_server**: Configured as **P2PHANDSHAKE**.
 **protocol:** Configured for Ascend to use Mooncake's HCCL communication.
 **device_name**: ""
-**use_ascend_direct**: Indicator for using ADXL engine.
 **alloc_in_same_node**: Indicator for preferring local buffer allocation strategy.
 **master_server_address**: Configured with the IP and port of the master service.
 **global_segment_size**: Expands the kvcache size registered by the PD node to the master.
@@ -133,7 +131,7 @@ python3 -m vllm.entrypoints.openai.api_server \
             }
         ]
     }
-  }' > p.log 2>&1
+  }'
 ```

 `decode` Node:
@@ -177,7 +175,6 @@ python3 -m vllm.entrypoints.openai.api_server \
     "kv_role": "kv_consumer",
     "kv_port": "20002",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 1,
         "tp_size": 1
@@ -196,7 +193,7 @@ python3 -m vllm.entrypoints.openai.api_server \
             }
         ]
     }
-  }' > d.log 2>&1
+  }'
 ```

 #### 2、Start proxy_server.
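Reviewer note (illustrative, not part of the patch): with `use_ascend_direct` gone, the remaining `kv_connector_extra_config` entries that matter are the `prefill`/`decode` parallel sizes, and this patch adds a check in `vllm_ascend/ascend_config.py` (`check_kv_extra_config`) that they match the serving instance. A minimal standalone sketch of that expectation, with example values, assuming a dict-shaped extra config like the ones above:

```python
# Sketch only: mirrors the intent of check_kv_extra_config added later in this
# diff. The extra_config values below are examples, not required settings.
extra_config = {
    "prefill": {"dp_size": 2, "tp_size": 8},
    "decode": {"dp_size": 2, "tp_size": 8},
}


def validate(section: str, expected_tp: int, expected_dp: int) -> None:
    # Each section may omit tp_size/dp_size; when present, they must match
    # the tensor/data parallel sizes passed to the corresponding vllm serve.
    cfg = extra_config.get(section, {})
    if "tp_size" in cfg and cfg["tp_size"] != expected_tp:
        raise ValueError(f"{section}: tp_size {cfg['tp_size']} != {expected_tp}")
    if "dp_size" in cfg and cfg["dp_size"] != expected_dp:
        raise ValueError(f"{section}: dp_size {cfg['dp_size']} != {expected_dp}")


validate("prefill", expected_tp=8, expected_dp=2)
```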
diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index 1179e328..92305170 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -639,10 +639,15 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
     def setUp(self):
         config = MockVllmConfig()
         self.p1 = patch(
-            'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
-            new=MagicMock(return_value=None))
+            'vllm_ascend.distributed.mooncake_connector.init_ascend_config',
+            new=MagicMock())
+        self.p2 = patch(
+            'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+            new=MagicMock(return_value=MagicMock()))
         self.p1.start()
+        self.p2.start()
         self.addCleanup(self.p1.stop)
+        self.addCleanup(self.p2.stop)
         self.scheduler = MooncakeConnectorScheduler(config, "test_engine")

     def test_get_num_new_matched_tokens(self):
@@ -716,7 +721,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
         config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
             self.assertIsNotNone(connector.connector_scheduler)
             self.assertIsNone(connector.connector_worker)
@@ -726,7 +733,9 @@
         config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
             connector.get_num_new_matched_tokens(request, 0)
@@ -756,7 +765,9 @@ class TestMooncakeConnector(unittest.TestCase):
     def test_scheduler_initialization(self):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             self.assertIsNotNone(connector.connector_scheduler)
@@ -766,7 +777,9 @@
     def test_get_num_new_matched_tokens(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -777,7 +790,9 @@
     def test_update_state_after_alloc(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -789,7 +804,9 @@
     def test_build_connector_meta(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             scheduler_output = MockSchedulerOutput()
@@ -800,7 +817,9 @@ class TestMooncakeConnector(unittest.TestCase):
     def test_request_finished(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -814,7 +833,9 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
         self.config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             self.scheduler = MooncakeConnectorScheduler(
                 self.config, "test_engine")

@@ -1037,9 +1058,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_pcp_group.device_group = MagicMock()

         self.patches = [
-            patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
-                '10,11'),
             patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
             patch('torch.Tensor.element_size', return_value=4),
             patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -1056,8 +1074,11 @@
                 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
                 mock_string_to_int64_hash),
             patch(
-                'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine',
+                'vllm_ascend.distributed.mooncake_connector.global_te.get_transfer_engine',
                 return_value=self.mock_transfer_engine),
+            patch(
+                'vllm_ascend.distributed.mooncake_connector.global_te.register_buffer',
+                return_value=None),
             patch(
                 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
                 MagicMock()),
@@ -1073,10 +1094,13 @@
             patch('vllm.distributed.parallel_state._DCP',
                   return_value=self.mock_dcp),
             patch(
-                'vllm.distributed.get_decode_context_model_parallel_world_size',
+                'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
                 return_value=1),
             patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
                   return_value=self.mock_pcp_group),
+            patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()),
         ]

         for p in self.patches:
@@ -1090,46 +1114,6 @@
         for p in self.patches:
             p.stop()  # type: ignore

-    def test_worker_use_ascend_direct(self):
-        test_case = [True, False]
-
-        for use_ascend_direct in test_case:
-            with self.subTest(use_ascend_direct=use_ascend_direct):
-                config = MagicMock()
-                config.kv_transfer_config = MagicMock()
-                config.kv_transfer_config.get_from_extra_config.side_effect = (
-                    lambda k, d: {
-                        "prefill": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "decode": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "use_ascend_direct": use_ascend_direct,
-                    }.get(k, d))
-
-                config.parallel_config = MagicMock()
-                config.parallel_config.tensor_parallel_size = 2
-                config.parallel_config.data_parallel_rank = 0
-                config.parallel_config.data_parallel_size_local = 1
-                config.kv_transfer_config.kv_port = 8000
-                config.kv_transfer_config.kv_role = 'worker'
-
-                with patch(
-                        "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
-                        return_value=0):
-                    with patch(
-                            "vllm_ascend.distributed.mooncake_connector.get_tp_group",
-                            return_value=None):
-                        with patch(
-                                "vllm_ascend.distributed.mooncake_connector.get_ip",
-                                return_value="127.0.0.1"):
-                            worker = MooncakeConnectorWorker(
-                                config, self.engine_id)
-                            self.assertIsNotNone(worker)
-
     def test_register_kv_caches_producer(self):
         worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
         worker.register_kv_caches(self.kv_caches)
@@ -1160,7 +1144,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         # Test with physical devices set
         worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
         # Default tp_rank is 0, so device_id should be 10
-        self.assertEqual(worker.device_id, 10)
+        self.assertIsNotNone(worker.engine)


 if __name__ == '__main__':
diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
index 28504c9b..ca4f975c 100644
--- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
@@ -58,6 +58,7 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
                 6000],  # 2 * total_layers
             use_mla=True,
             block_len=[1024, 2048],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())

@@ -97,6 +98,7 @@
             kv_cache_base_addr=[1111, 2222, 3333, 4444],
             use_mla=False,
             block_len=[64],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())

@@ -155,6 +157,7 @@
             kv_cache_base_addr=[1000, 2000],
             use_mla=False,
             block_len=[1024],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())
         req_meta = self.req_meta_base
@@ -397,7 +400,6 @@
                 "tp_size": 2,
                 "dp_size": 1
             },
-            "use_ascend_direct": True,
         }.get(k, d)


@@ -806,9 +808,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.register_memory.return_value = 0

         self.patches = [
-            patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
-                '10,11'),
             patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
             patch('torch.Tensor.element_size', return_value=4),
             patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -827,8 +826,11 @@
                 'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash',
                 side_effect=lambda s: hash(s)),
             patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.TransferEngine',
+                'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.get_transfer_engine',
                 return_value=self.mock_transfer_engine),
+            patch(
+                'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.register_buffer',
+                return_value=None),
             patch(
                 'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread',
                 MagicMock()),
@@ -859,26 +861,6 @@
         for p in self.patches:
             p.stop()  # type: ignore

-    def test_worker_use_ascend_direct(self):
-        for use_ascend_direct in (True, False):
-            with self.subTest(use_ascend_direct=use_ascend_direct):
-                config = MockVllmConfig()
-                config.kv_transfer_config.get_from_extra_config.side_effect = (
-                    lambda k, d: {
-                        "prefill": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "decode": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "use_ascend_direct": use_ascend_direct,
-                    }.get(k, d))
-                worker = MooncakeLayerwiseConnectorWorker(
-                    config, self.engine_id)
-                self.assertIsNotNone(worker)
-
     def test_register_kv_caches_producer(self):
         self.vllm_config.kv_transfer_config.is_kv_producer = True

@@ -915,7 +897,7 @@ def test_device_id_selection_with_physical_devices(self):
         worker = MooncakeLayerwiseConnectorWorker(self.vllm_config,
                                                   self.engine_id)
-        self.assertEqual(worker.device_id, 10)
+        self.assertIsNotNone(worker.engine)


 if __name__ == '__main__':
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 8b485c17..612e2bb1 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -28,6 +28,37 @@ def _check_torchair_supported(model_type: str):
     return False


+def check_kv_extra_config(vllm_config):
+
+    def _check(name: str, config: dict):
+        tp_key = "tp_size"
+        dp_key = "dp_size"
+        if tp_key in config:
+            config_tp = config[tp_key]
+            vllm_tp = vllm_config.parallel_config.tensor_parallel_size
+            if config_tp != vllm_tp:
+                raise ValueError(
+                    f"KV transfer '{name}' config has a conflicting tensor parallel size. "
+                    f"Expected {vllm_tp}, but got {config_tp}.")
+        if dp_key in config:
+            config_dp = config[dp_key]
+            vllm_dp = vllm_config.parallel_config.data_parallel_size
+            if config_dp != vllm_dp:
+                raise ValueError(
+                    f"KV transfer '{name}' config has a conflicting data parallel size. "
+                    f"Expected {vllm_dp}, but got {config_dp}.")
+
+    if vllm_config.kv_transfer_config.is_kv_producer:
+        _check(
+            "prefill",
+            vllm_config.kv_transfer_config.get_from_extra_config(
+                "prefill", {}))
+    if vllm_config.kv_transfer_config.is_kv_consumer:
+        _check(
+            "decode",
+            vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
+
+
 class AscendConfig:
     """
     Configuration Object for additional_config from vllm.configs.
@@ -112,6 +143,10 @@
         )
         self.enable_cpu_binding = additional_config.get(
             "enable_cpu_binding", False)
+
+        if vllm_config.kv_transfer_config is not None:
+            check_kv_extra_config(vllm_config)
+
         self.pd_tp_ratio = 1
         self.pd_head_ratio = 1
         self.num_head_replica = 1
diff --git a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
index 7d9bfedd..6fb0d259 100644
--- a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
+++ b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
@@ -83,7 +83,6 @@ class MooncakeStoreConfig:
     protocol: str
     device_name: str
     master_server_address: str
-    use_ascend_direct: bool

     @staticmethod
     def from_file(file_path: str) -> "MooncakeStoreConfig":
@@ -99,8 +98,7 @@
                 DEFAULT_LOCAL_BUFFER_SIZE)),
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", ""),
-            master_server_address=config.get("master_server_address"),
-            use_ascend_direct=config.get("use_ascend_direct", False))
+            master_server_address=config.get("master_server_address"))

     @staticmethod
     def load_from_env() -> "MooncakeStoreConfig":
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 8fe21146..8303c2dc 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -35,7 +35,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.request import RequestStatus

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
@@ -653,6 +652,7 @@ class MooncakeConnector(KVConnectorBase_V1):
                  kv_cache_config: Optional[KVCacheConfig] = None):
         assert vllm_config.kv_transfer_config is not None
         self.engine_id = vllm_config.kv_transfer_config.engine_id
+        self._connector_metadata = MooncakeConnectorMetadata()

         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
@@ -744,9 +744,6 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
-        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size * \
-            self.pcp_size

         # Handshake base port
         self.side_channel_port = (
@@ -905,8 +902,6 @@ class MooncakeConnectorWorker:
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_group = get_tp_group()
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
         self.pcp_size = get_pcp_group().world_size
@@ -916,7 +911,6 @@
         self.dcp_rank = get_decode_context_model_parallel_rank(
         ) if self.dcp_size > 1 else 0

-        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role

         self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -927,38 +921,9 @@
             vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
         self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
-
-        # get tp device id
-        # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
-        # introducing some changes
-        device_ids_str = envs_ascend.PHYSICAL_DEVICES
-        if device_ids_str is None:
-            device_ids = list(
-                range(self.dp_rank * self.tp_size * self.pcp_size,
-                      (self.dp_rank + 1) * self.tp_size * self.pcp_size))
-        else:
-            device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size * self.pcp_size
-            end_index = start_index + self.tp_size * self.pcp_size
-            if len(device_ids) < end_index:
-                raise ValueError(
-                    f"Not enough physical devices available for DP rank {self.dp_rank}. "
-                    f"Expected at least {end_index} devices, but found {len(device_ids)} "
-                    "in PHYSICAL_DEVICES.")
-            device_ids = device_ids[start_index:end_index]
-        assert len(
-            device_ids
-        ) > self.pcp_rank * self.tp_size + self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.pcp_rank * self.tp_size +
-                                    self.tp_rank]  # type: ignore
-
-        if vllm_config.kv_transfer_config.get_from_extra_config(
-                'use_ascend_direct', True):
-            hostname = self.side_channel_host
-        else:
-            hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
         logger.info("Initializing Mooncake work %s", engine_id)
-        self.engine = global_te.get_transfer_engine(hostname, device_name=None)
+        self.engine = global_te.get_transfer_engine(self.side_channel_host,
+                                                    device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()

         # Background thread for sending or receiving KV caches.
@@ -998,19 +963,6 @@ class MooncakeConnectorWorker:
         assert "dp_size" in decode_parallel_config.keys()
         self._decode_dp_size = decode_parallel_config["dp_size"]

-    def _initialize(
-        self,
-        hostname: str,
-        device_name: Optional[str],
-    ) -> None:
-        """Initialize the mooncake instance."""
-        device_name = device_name if device_name is not None else ""
-        ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
-                                           device_name)
-        if ret_value != 0:
-            raise RuntimeError(
-                f"Mooncake initialization failed with ret_value: {ret_value}")
-
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""

diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py
index 1f5c44e1..d1351049 100644
--- a/vllm_ascend/distributed/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -32,8 +32,8 @@ from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import (align_memory,
                                            get_transfer_timeout_value,
                                            kv_alltoall_and_rearrange)
@@ -100,6 +100,7 @@ class KVCacheSendingLayerThread(threading.Thread):
                  kv_cache_base_addr: list[int],
                  use_mla: bool,
                  block_len: list[int],
+                 decode_tp_size: int,
                  first_kv_cache: torch.Tensor,
                  callback_func: Callable[..., None] = lambda x: None):
         super().__init__(daemon=True, name="KVCacheSendingLayerThread")
@@ -111,6 +112,7 @@
         self.total_layers = total_layers
         self.use_mla = use_mla
         self.block_len = block_len
+        self._decode_tp_size = decode_tp_size
         self.model_stream = torch_npu.npu.current_stream()
         self.current_layer = -1

@@ -172,9 +174,20 @@
     def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
         # send kv layer to remote
         if len(req_meta.local_block_ids) == 0:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
+            )
             return  # not need to send kv cache

         if self.tp_rank % self.num_head_replica != 0:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
+            )
+            return
+        if self.use_mla and self.tp_rank >= self._decode_tp_size:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
+            )
             return

         remote_host = req_meta.remote_host
@@ -484,8 +497,6 @@ class MooncakeLayerwiseConnectorScheduler:

         logger.info("Initializing Mooncake Scheduler %s", engine_id)
         self.side_channel_host = get_ip()
-        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size

         # Handshake base port
         self.side_channel_port = (
@@ -550,6 +561,9 @@
             local_block_ids = (blocks.get_unhashed_block_ids()
                                if num_external_tokens > 0 else [])
             # Get unhashed blocks to pull from remote.
+            logger.debug(
+                f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue"
+            )
             self._reqs_need_recv[request.request_id] = (
                 request,
                 [],  #request._all_token_ids,
@@ -560,6 +574,9 @@
         # Layerwise prefiller add request need send
         if params is not None and params.get("do_remote_decode"):
             local_block_ids = (blocks.get_block_ids()[0])
+            logger.debug(
+                f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
+            )
             self._reqs_need_send_layerwise[request.request_id] = (len(
                 request.all_token_ids), local_block_ids, request)

@@ -603,12 +620,19 @@
                     req_id]
                 current_tokens = computed_tokens.get(req_id, 0) + scheduled_tokens
-                if current_tokens == total_tokens:
+                if current_tokens >= total_tokens:
+                    logger.debug(
+                        f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
+                    )
                     meta.add_new_req(request_id=req_id,
                                      local_block_ids=block_ids,
                                      kv_transfer_params=req.kv_transfer_params,
                                      token_ids=[])
                     self._reqs_need_send_layerwise.pop(req_id)
+                else:
+                    logger.debug(
+                        f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
+                    )

         return meta

     def request_finished(
@@ -639,7 +663,6 @@ class MooncakeLayerwiseConnectorWorker:
         if TransferEngine is None:
             raise RuntimeError("mooncake is not available")
         logger.info("Initializing Mooncake work %s", engine_id)
-        self.engine = TransferEngine()

         # Metadata.
         self.vllm_config = vllm_config
@@ -648,11 +671,8 @@
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_group = get_tp_group()
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
-        self.max_device_id = self.tp_size * self.dp_size

         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
@@ -668,34 +688,9 @@
             vllm_config.parallel_config.tensor_parallel_size)
         self.handshake_port = self.side_channel_port + self.tp_rank
         self.sockets: dict = {}
-
-        # get tp device id
-        # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
-        # introducing some changes
-        device_ids_str = envs_ascend.PHYSICAL_DEVICES
-        if device_ids_str is None:
-            device_ids = list(
-                range(self.dp_rank * self.tp_size,
-                      (self.dp_rank + 1) * self.tp_size))
-        else:
-            device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size
-            end_index = start_index + self.tp_size
-            if len(device_ids) < end_index:
-                raise ValueError(
-                    f"Not enough physical devices available for DP rank {self.dp_rank}. "
-                    f"Expected at least {end_index} devices, but found {len(device_ids)} "
-                    "in PHYSICAL_DEVICES.")
-            device_ids = device_ids[start_index:end_index]
-        assert len(device_ids) > self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.tp_rank]  # type: ignore
-
-        if vllm_config.kv_transfer_config.get_from_extra_config(
-                'use_ascend_direct', True):
-            hostname = self.side_channel_host
-        else:
-            hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
-        self._initialize(hostname=hostname, device_name=None)
+        logger.info("Initializing Mooncake work %s", engine_id)
+        self.engine = global_te.get_transfer_engine(self.side_channel_host,
+                                                    device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()

         # Background thread for sending or receiving KV caches.
@@ -747,19 +742,6 @@
         assert "dp_size" in decode_parallel_config.keys()
         self._decode_dp_size = decode_parallel_config["dp_size"]

-    def _initialize(
-        self,
-        hostname: str,
-        device_name: Optional[str],
-    ) -> None:
-        """Initialize the mooncake instance."""
-        device_name = device_name if device_name is not None else ""
-        ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
-                                           device_name)
-        if ret_value != 0:
-            raise RuntimeError(
-                f"Mooncake initialization failed with ret_value: {ret_value}")
-
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""

@@ -798,6 +780,8 @@
         self.kv_caches = kv_caches

         kv_caches_base_addr = []
+        ptrs = []
+        lengths = []
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
             if self.use_mla:
@@ -805,7 +789,8 @@
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[i % 2]
                     kv_caches_base_addr.append(base_addr)
-                    self._register(base_addr, region_len)
+                    ptrs.append(base_addr)
+                    lengths.append(region_len)
             else:
                 cache_list = [cache_or_caches
                               ] if self.use_mla else cache_or_caches
@@ -813,7 +798,9 @@
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[0]
                     kv_caches_base_addr.append(base_addr)
-                    self._register(base_addr, region_len)
+                    ptrs.append(base_addr)
+                    lengths.append(region_len)
+        global_te.register_buffer(ptrs, lengths)
         self.kv_caches_base_addr = kv_caches_base_addr

         # After KV Caches registered, start the sending or receiving thread.
@@ -833,6 +820,7 @@
                 kv_cache_base_addr=self.kv_caches_base_addr,
                 use_mla=self.use_mla,
                 block_len=self.block_len,
+                decode_tp_size=self._decode_tp_size,
                 first_kv_cache=first_kv_cache,
                 callback_func=self.send_done_send_signal)
             self.kv_send_layer_thread.start()
@@ -846,14 +834,6 @@
             self.kv_recv_layer_thread.start()
         ready_event.wait()

-    def _register(self, ptr, length):
-        logger.info(
-            "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
-            "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
-        ret_value = self.engine.register_memory(ptr, length)
-        if ret_value != 0:
-            raise RuntimeError("Mooncake memory registration failed.")
-
     def _access_metaserver(self, url, message):
         success = False
         retry = 0
@@ -969,9 +949,6 @@
         key = None
         value = None
         for req_id, req_meta in connector_metadata.requests.items():
-            logger.debug(
-                f"Add request {req_id} to kv send layer thread. {req_meta=}"
-            )
             if self.pd_head_ratio != 1:
                 key_block_num = len(
                     req_meta.local_block_ids) * key_block_size
@@ -983,6 +960,9 @@
                 key_start_id += key_block_num
                 value_start_id += value_block_num
             req_meta_update = self.update_decoder_info(req_id, req_meta)
+            logger.debug(
+                f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
+            )
             assert self.kv_send_layer_thread is not None
             self.kv_send_layer_thread.send_queue.put(
                 (req_id, req_meta_update, self.current_layer, key, value))
@@ -1011,10 +991,8 @@
     def update_decoder_info(self, req_id, req_meta):
         req_meta_update = copy.deepcopy(req_meta)
-        if self.pd_tp_ratio > 1:
-            req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio
-        else:
-            req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank
+        req_meta_update.remote_port = req_meta_update.remote_port + (
+            self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size

         if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
            req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
            try:
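Reviewer note (illustrative, not part of the patch): the new `remote_port` formula in `update_decoder_info` bounds the port offset by the decode TP size. A small worked sketch, assuming `pd_tp_ratio` is the prefill-to-decode TP ratio and `remote_port` is the decoder's base handshake port (both assumptions, not stated in this hunk):

```python
# Sketch only: reproduces the arithmetic of the new line
#   remote_port + (tp_rank // pd_tp_ratio) % decode_tp_size
def target_port(base_port: int, tp_rank: int, pd_tp_ratio: int,
                decode_tp_size: int) -> int:
    return base_port + (tp_rank // pd_tp_ratio) % decode_tp_size


# Example: prefill TP=4, decode TP=2 -> pd_tp_ratio=2, decode_tp_size=2.
# Prefill ranks 0,1 map to base+0 and ranks 2,3 map to base+1, so the result
# always stays within [base, base + decode_tp_size), unlike the old
# "+ tp_rank" branch, which could point past the decoder's port range.
assert [target_port(20002, r, 2, 2) for r in range(4)] == [20002, 20002, 20003, 20003]
```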