[P/D] check kv extra config and del hccl backend (#4547)

### What this PR does / why we need it?
Validate the `kv_connector_extra_config` passed to the Mooncake connectors against the running instance's parallel configuration (the `tp_size`/`dp_size` entries under `prefill`/`decode` must match vLLM's tensor- and data-parallel sizes), and remove the HCCL backend path: the `use_ascend_direct` switch is deleted and the transfer engine is now always obtained through the shared `global_te` helper.
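For illustration, a minimal self-contained sketch of the validation this PR adds. The real check is `check_kv_extra_config` in `vllm_ascend/ascend_config.py` (see the hunk below); `check_parallel_sizes` here is a hypothetical stand-in that only mirrors its logic:

```python
# Sketch of the tp_size/dp_size consistency check added in this PR.
# check_parallel_sizes is a stand-in for the real check_kv_extra_config.
def check_parallel_sizes(name: str, config: dict, vllm_tp: int, vllm_dp: int):
    if "tp_size" in config and config["tp_size"] != vllm_tp:
        raise ValueError(f"KV transfer '{name}' config has a conflicting "
                         f"tensor parallel size. Expected {vllm_tp}, "
                         f"but got {config['tp_size']}.")
    if "dp_size" in config and config["dp_size"] != vllm_dp:
        raise ValueError(f"KV transfer '{name}' config has a conflicting "
                         f"data parallel size. Expected {vllm_dp}, "
                         f"but got {config['dp_size']}.")

# With a tp=8 deployment, a stale extra config now fails fast at startup:
try:
    check_parallel_sizes("prefill", {"tp_size": 4, "dp_size": 1},
                         vllm_tp=8, vllm_dp=1)
except ValueError as e:
    print(e)  # conflicting tensor parallel size: expected 8, got 4
```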


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: liziyu
Committed: 2025-12-07 15:19:42 +08:00 (via GitHub)
Parent: b91a5f0968
Commit: 688b1332da
8 changed files with 133 additions and 211 deletions


@@ -430,7 +430,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
     "engine_id": "0",
     "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 2,
         "tp_size": 8
@@ -510,7 +509,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
     "engine_id": "1",
     "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 2,
         "tp_size": 8
@@ -590,7 +588,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
     "engine_id": "2",
     "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 2,
         "tp_size": 8
@@ -670,7 +667,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
     "engine_id": "3",
     "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 2,
         "tp_size": 8


@@ -41,7 +41,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
     "metadata_server": "P2PHANDSHAKE",
     "protocol": "ascend",
     "device_name": "",
-    "use_ascend_direct": true,
     "alloc_in_same_node": true,
     "master_server_address": "xx.xx.xx.xx:50088",
     "global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
@@ -52,7 +51,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
 **metadata_server**: Configured as **P2PHANDSHAKE**.
 **protocol:** Configured for Ascend to use Mooncake's HCCL communication.
 **device_name**: ""
-**use_ascend_direct**: Indicator for using ADXL engine.
 **alloc_in_same_node**: Indicator for preferring local buffer allocation strategy.
 **master_server_address**: Configured with the IP and port of the master service.
 **global_segment_size**: Expands the kvcache size registered by the PD node to the master.
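For reference, a minimal `mooncake.json` under the updated documentation keeps only the remaining fields. A sketch with illustrative values taken from the hunk above (the master address is a placeholder), emitted as JSON from Python:

```python
import json

# Illustrative mooncake.json contents once "use_ascend_direct" is removed.
mooncake_config = {
    "metadata_server": "P2PHANDSHAKE",
    "protocol": "ascend",          # Mooncake communication for Ascend
    "device_name": "",
    "alloc_in_same_node": True,    # prefer local buffer allocation
    "master_server_address": "xx.xx.xx.xx:50088",
    "global_segment_size": "1GB",  # kvcache size registered to the master
}
print(json.dumps(mooncake_config, indent=2))
```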
@@ -133,7 +131,7 @@ python3 -m vllm.entrypoints.openai.api_server \
         }
       ]
     }
-  }' > p.log 2>&1
+  }'
 ```

 `decode` Node
@@ -177,7 +175,6 @@ python3 -m vllm.entrypoints.openai.api_server \
     "kv_role": "kv_consumer",
     "kv_port": "20002",
     "kv_connector_extra_config": {
-      "use_ascend_direct": true,
       "prefill": {
         "dp_size": 1,
         "tp_size": 1
@@ -196,7 +193,7 @@ python3 -m vllm.entrypoints.openai.api_server \
         }
       ]
     }
-  }' > d.log 2>&1
+  }'
 ```

 #### 2、Start proxy_server.


@@ -639,10 +639,15 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
     def setUp(self):
         config = MockVllmConfig()
         self.p1 = patch(
-            'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
-            new=MagicMock(return_value=None))
+            'vllm_ascend.distributed.mooncake_connector.init_ascend_config',
+            new=MagicMock())
+        self.p2 = patch(
+            'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+            new=MagicMock(return_value=MagicMock()))
         self.p1.start()
+        self.p2.start()
         self.addCleanup(self.p1.stop)
+        self.addCleanup(self.p2.stop)
         self.scheduler = MooncakeConnectorScheduler(config, "test_engine")

     def test_get_num_new_matched_tokens(self):
@@ -716,7 +721,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
         config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
             self.assertIsNotNone(connector.connector_scheduler)
             self.assertIsNone(connector.connector_worker)
@@ -726,7 +733,9 @@
         config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
             connector.get_num_new_matched_tokens(request, 0)
@@ -756,7 +765,9 @@ class TestMooncakeConnector(unittest.TestCase):
     def test_scheduler_initialization(self):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             self.assertIsNotNone(connector.connector_scheduler)
@@ -766,7 +777,9 @@
     def test_get_num_new_matched_tokens(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -777,7 +790,9 @@
     def test_update_state_after_alloc(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -789,7 +804,9 @@
     def test_build_connector_meta(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             scheduler_output = MockSchedulerOutput()
@@ -800,7 +817,9 @@
     def test_request_finished(self, mock_method):
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             connector = MooncakeConnector(self.config,
                                           KVConnectorRole.SCHEDULER)
             request = MockRequest("req1")
@@ -814,7 +833,9 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
         self.config = MockVllmConfig()
         with patch(
                 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
-        ):
+        ), patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()):
             self.scheduler = MooncakeConnectorScheduler(
                 self.config, "test_engine")
@@ -1037,9 +1058,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_pcp_group.device_group = MagicMock()

         self.patches = [
-            patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
-                '10,11'),
             patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
             patch('torch.Tensor.element_size', return_value=4),
             patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -1056,8 +1074,11 @@
                 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
                 mock_string_to_int64_hash),
             patch(
-                'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine',
+                'vllm_ascend.distributed.mooncake_connector.global_te.get_transfer_engine',
                 return_value=self.mock_transfer_engine),
+            patch(
+                'vllm_ascend.distributed.mooncake_connector.global_te.register_buffer',
+                return_value=None),
             patch(
                 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
                 MagicMock()),
@@ -1073,10 +1094,13 @@
             patch('vllm.distributed.parallel_state._DCP',
                   return_value=self.mock_dcp),
             patch(
-                'vllm.distributed.get_decode_context_model_parallel_world_size',
+                'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
                 return_value=1),
             patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
                   return_value=self.mock_pcp_group),
+            patch(
+                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
+                return_value=MagicMock()),
         ]

         for p in self.patches:
@@ -1090,46 +1114,6 @@
         for p in self.patches:
             p.stop()  # type: ignore

-    def test_worker_use_ascend_direct(self):
-        test_case = [True, False]
-        for use_ascend_direct in test_case:
-            with self.subTest(use_ascend_direct=use_ascend_direct):
-                config = MagicMock()
-                config.kv_transfer_config = MagicMock()
-                config.kv_transfer_config.get_from_extra_config.side_effect = (
-                    lambda k, d: {
-                        "prefill": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "decode": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "use_ascend_direct": use_ascend_direct,
-                    }.get(k, d))
-                config.parallel_config = MagicMock()
-                config.parallel_config.tensor_parallel_size = 2
-                config.parallel_config.data_parallel_rank = 0
-                config.parallel_config.data_parallel_size_local = 1
-                config.kv_transfer_config.kv_port = 8000
-                config.kv_transfer_config.kv_role = 'worker'
-                with patch(
-                        "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
-                        return_value=0):
-                    with patch(
-                            "vllm_ascend.distributed.mooncake_connector.get_tp_group",
-                            return_value=None):
-                        with patch(
-                                "vllm_ascend.distributed.mooncake_connector.get_ip",
-                                return_value="127.0.0.1"):
-                            worker = MooncakeConnectorWorker(
-                                config, self.engine_id)
-                            self.assertIsNotNone(worker)
-
     def test_register_kv_caches_producer(self):
         worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
         worker.register_kv_caches(self.kv_caches)
@@ -1160,7 +1144,7 @@
         # Test with physical devices set
         worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
         # Default tp_rank is 0, so device_id should be 10
-        self.assertEqual(worker.device_id, 10)
+        self.assertIsNotNone(worker.engine)


 if __name__ == '__main__':


@@ -58,6 +58,7 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
                 6000],  # 2 * total_layers
             use_mla=True,
             block_len=[1024, 2048],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())
@@ -97,6 +98,7 @@
             kv_cache_base_addr=[1111, 2222, 3333, 4444],
             use_mla=False,
             block_len=[64],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())
@@ -155,6 +157,7 @@
             kv_cache_base_addr=[1000, 2000],
             use_mla=False,
             block_len=[1024],
+            decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
             callback_func=MagicMock())
         req_meta = self.req_meta_base
@@ -397,7 +400,6 @@ class MockVllmConfig:
                 "tp_size": 2,
                 "dp_size": 1
             },
-            "use_ascend_direct": True,
         }.get(k, d)
@@ -806,9 +808,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.register_memory.return_value = 0

         self.patches = [
-            patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
-                '10,11'),
             patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
             patch('torch.Tensor.element_size', return_value=4),
             patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -827,8 +826,11 @@
                 'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash',
                 side_effect=lambda s: hash(s)),
             patch(
-                'vllm_ascend.distributed.mooncake_layerwise_connector.TransferEngine',
+                'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.get_transfer_engine',
                 return_value=self.mock_transfer_engine),
+            patch(
+                'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.register_buffer',
+                return_value=None),
             patch(
                 'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread',
                 MagicMock()),
@@ -859,26 +861,6 @@
         for p in self.patches:
             p.stop()  # type: ignore

-    def test_worker_use_ascend_direct(self):
-        for use_ascend_direct in (True, False):
-            with self.subTest(use_ascend_direct=use_ascend_direct):
-                config = MockVllmConfig()
-                config.kv_transfer_config.get_from_extra_config.side_effect = (
-                    lambda k, d: {
-                        "prefill": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "decode": {
-                            "tp_size": 2,
-                            "dp_size": 1
-                        },
-                        "use_ascend_direct": use_ascend_direct,
-                    }.get(k, d))
-                worker = MooncakeLayerwiseConnectorWorker(
-                    config, self.engine_id)
-                self.assertIsNotNone(worker)
-
     def test_register_kv_caches_producer(self):
         self.vllm_config.kv_transfer_config.is_kv_producer = True
@@ -915,7 +897,7 @@
     def test_device_id_selection_with_physical_devices(self):
         worker = MooncakeLayerwiseConnectorWorker(self.vllm_config,
                                                   self.engine_id)
-        self.assertEqual(worker.device_id, 10)
+        self.assertIsNotNone(worker.engine)


 if __name__ == '__main__':


@@ -28,6 +28,37 @@ def _check_torchair_supported(model_type: str):
     return False


+def check_kv_extra_config(vllm_config):
+
+    def _check(name: str, config: dict):
+        tp_key = "tp_size"
+        dp_key = "dp_size"
+        if tp_key in config:
+            config_tp = config[tp_key]
+            vllm_tp = vllm_config.parallel_config.tensor_parallel_size
+            if config_tp != vllm_tp:
+                raise ValueError(
+                    f"KV transfer '{name}' config has a conflicting tensor parallel size. "
+                    f"Expected {vllm_tp}, but got {config_tp}.")
+        if dp_key in config:
+            config_dp = config[dp_key]
+            vllm_dp = vllm_config.parallel_config.data_parallel_size
+            if config_dp != vllm_dp:
+                raise ValueError(
+                    f"KV transfer '{name}' config has a conflicting data parallel size. "
+                    f"Expected {vllm_dp}, but got {config_dp}.")
+
+    if vllm_config.kv_transfer_config.is_kv_producer:
+        _check(
+            "prefill",
+            vllm_config.kv_transfer_config.get_from_extra_config(
+                "prefill", {}))
+    if vllm_config.kv_transfer_config.is_kv_consumer:
+        _check(
+            "decode",
+            vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
+
+
 class AscendConfig:
     """
     Configuration Object for additional_config from vllm.configs.
@@ -112,6 +143,10 @@ class AscendConfig:
         )
         self.enable_cpu_binding = additional_config.get(
             "enable_cpu_binding", False)
+
+        if vllm_config.kv_transfer_config is not None:
+            check_kv_extra_config(vllm_config)
+
         self.pd_tp_ratio = 1
         self.pd_head_ratio = 1
         self.num_head_replica = 1


@@ -83,7 +83,6 @@ class MooncakeStoreConfig:
     protocol: str
     device_name: str
     master_server_address: str
-    use_ascend_direct: bool

     @staticmethod
     def from_file(file_path: str) -> "MooncakeStoreConfig":
@@ -99,8 +98,7 @@ class MooncakeStoreConfig:
                 DEFAULT_LOCAL_BUFFER_SIZE)),
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", ""),
-            master_server_address=config.get("master_server_address"),
-            use_ascend_direct=config.get("use_ascend_direct", False))
+            master_server_address=config.get("master_server_address"))

     @staticmethod
     def load_from_env() -> "MooncakeStoreConfig":


@@ -35,7 +35,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.request import RequestStatus

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
@@ -653,6 +652,7 @@ class MooncakeConnector(KVConnectorBase_V1):
                  kv_cache_config: Optional[KVCacheConfig] = None):
         assert vllm_config.kv_transfer_config is not None
         self.engine_id = vllm_config.kv_transfer_config.engine_id
+        self._connector_metadata = MooncakeConnectorMetadata()

         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
@@ -744,9 +744,6 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
-        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size * \
-            self.pcp_size

         # Handshake base port
         self.side_channel_port = (
@@ -905,8 +902,6 @@ class MooncakeConnectorWorker:
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_group = get_tp_group()
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
         self.pcp_size = get_pcp_group().world_size
@@ -916,7 +911,6 @@ class MooncakeConnectorWorker:
         self.dcp_rank = get_decode_context_model_parallel_rank(
         ) if self.dcp_size > 1 else 0
-        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role

         self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -927,38 +921,9 @@ class MooncakeConnectorWorker:
             vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
         self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
-
-        # get tp device id
-        # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
-        # introducing some changes
-        device_ids_str = envs_ascend.PHYSICAL_DEVICES
-        if device_ids_str is None:
-            device_ids = list(
-                range(self.dp_rank * self.tp_size * self.pcp_size,
-                      (self.dp_rank + 1) * self.tp_size * self.pcp_size))
-        else:
-            device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size * self.pcp_size
-            end_index = start_index + self.tp_size * self.pcp_size
-            if len(device_ids) < end_index:
-                raise ValueError(
-                    f"Not enough physical devices available for DP rank {self.dp_rank}. "
-                    f"Expected at least {end_index} devices, but found {len(device_ids)} "
-                    "in PHYSICAL_DEVICES.")
-            device_ids = device_ids[start_index:end_index]
-        assert len(
-            device_ids
-        ) > self.pcp_rank * self.tp_size + self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.pcp_rank * self.tp_size +
-                                    self.tp_rank]  # type: ignore
-        if vllm_config.kv_transfer_config.get_from_extra_config(
-                'use_ascend_direct', True):
-            hostname = self.side_channel_host
-        else:
-            hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
         logger.info("Initializing Mooncake work %s", engine_id)
-        self.engine = global_te.get_transfer_engine(hostname, device_name=None)
+        self.engine = global_te.get_transfer_engine(self.side_channel_host,
+                                                    device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()

         # Background thread for sending or receiving KV caches.
@@ -998,19 +963,6 @@ class MooncakeConnectorWorker:
         assert "dp_size" in decode_parallel_config.keys()
         self._decode_dp_size = decode_parallel_config["dp_size"]

-    def _initialize(
-        self,
-        hostname: str,
-        device_name: Optional[str],
-    ) -> None:
-        """Initialize the mooncake instance."""
-        device_name = device_name if device_name is not None else ""
-        ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
-                                           device_name)
-        if ret_value != 0:
-            raise RuntimeError(
-                f"Mooncake initialization failed with ret_value: {ret_value}")
-
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""


@@ -32,8 +32,8 @@ from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import (align_memory,
                                            get_transfer_timeout_value,
                                            kv_alltoall_and_rearrange)
@@ -100,6 +100,7 @@ class KVCacheSendingLayerThread(threading.Thread):
                  kv_cache_base_addr: list[int],
                  use_mla: bool,
                  block_len: list[int],
+                 decode_tp_size: int,
                  first_kv_cache: torch.Tensor,
                  callback_func: Callable[..., None] = lambda x: None):
         super().__init__(daemon=True, name="KVCacheSendingLayerThread")
@@ -111,6 +112,7 @@ class KVCacheSendingLayerThread(threading.Thread):
         self.total_layers = total_layers
         self.use_mla = use_mla
         self.block_len = block_len
+        self._decode_tp_size = decode_tp_size
         self.model_stream = torch_npu.npu.current_stream()
         self.current_layer = -1
@@ -172,9 +174,20 @@ class KVCacheSendingLayerThread(threading.Thread):
     def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
         # send kv layer to remote
         if len(req_meta.local_block_ids) == 0:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
+            )
             return
         # not need to send kv cache
         if self.tp_rank % self.num_head_replica != 0:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
+            )
+            return
+        if self.use_mla and self.tp_rank >= self._decode_tp_size:
+            logger.debug(
+                f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
+            )
             return

         remote_host = req_meta.remote_host
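A small worked example of the guard added above (self-contained; the sizes are illustrative, not from the PR): with `use_mla=True` and a decode TP size of 1, only prefill TP rank 0 pushes the rank-replicated MLA cache.

```python
# Which prefill ranks send under the new guard (illustrative values).
use_mla, decode_tp_size, num_head_replica = True, 1, 1
for tp_rank in range(4):
    skip = (tp_rank % num_head_replica != 0) or \
           (use_mla and tp_rank >= decode_tp_size)
    print(f"tp_rank={tp_rank}: {'skip' if skip else 'send'}")
# tp_rank=0 sends; tp_rank=1..3 skip
```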
@@ -484,8 +497,6 @@ class MooncakeLayerwiseConnectorScheduler:
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
         self.side_channel_host = get_ip()
-        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size

         # Handshake base port
         self.side_channel_port = (
@@ -550,6 +561,9 @@ class MooncakeLayerwiseConnectorScheduler:
                 local_block_ids = (blocks.get_unhashed_block_ids()
                                    if num_external_tokens > 0 else [])
                 # Get unhashed blocks to pull from remote.
+                logger.debug(
+                    f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue"
+                )
                 self._reqs_need_recv[request.request_id] = (
                     request,
                     [],  #request._all_token_ids,
@@ -560,6 +574,9 @@ class MooncakeLayerwiseConnectorScheduler:
         # Layerwise prefiller add request need send
         if params is not None and params.get("do_remote_decode"):
             local_block_ids = (blocks.get_block_ids()[0])
+            logger.debug(
+                f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
+            )
             self._reqs_need_send_layerwise[request.request_id] = (len(
                 request.all_token_ids), local_block_ids, request)
@@ -603,12 +620,19 @@ class MooncakeLayerwiseConnectorScheduler:
                     req_id]
                 current_tokens = computed_tokens.get(req_id,
                                                      0) + scheduled_tokens
-                if current_tokens == total_tokens:
+                if current_tokens >= total_tokens:
+                    logger.debug(
+                        f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
+                    )
                     meta.add_new_req(request_id=req_id,
                                      local_block_ids=block_ids,
                                      kv_transfer_params=req.kv_transfer_params,
                                      token_ids=[])
                     self._reqs_need_send_layerwise.pop(req_id)
+                else:
+                    logger.debug(
+                        f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
+                    )

         return meta

     def request_finished(
@@ -639,7 +663,6 @@ class MooncakeLayerwiseConnectorWorker:
         if TransferEngine is None:
             raise RuntimeError("mooncake is not available")
         logger.info("Initializing Mooncake work %s", engine_id)
-        self.engine = TransferEngine()

         # Metadata.
         self.vllm_config = vllm_config
@@ -648,11 +671,8 @@ class MooncakeLayerwiseConnectorWorker:
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_group = get_tp_group()
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
-        self.max_device_id = self.tp_size * self.dp_size

         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
@@ -668,34 +688,9 @@ class MooncakeLayerwiseConnectorWorker:
             vllm_config.parallel_config.tensor_parallel_size)
         self.handshake_port = self.side_channel_port + self.tp_rank
         self.sockets: dict = {}
-
-        # get tp device id
-        # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
-        # introducing some changes
-        device_ids_str = envs_ascend.PHYSICAL_DEVICES
-        if device_ids_str is None:
-            device_ids = list(
-                range(self.dp_rank * self.tp_size,
-                      (self.dp_rank + 1) * self.tp_size))
-        else:
-            device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size
-            end_index = start_index + self.tp_size
-            if len(device_ids) < end_index:
-                raise ValueError(
-                    f"Not enough physical devices available for DP rank {self.dp_rank}. "
-                    f"Expected at least {end_index} devices, but found {len(device_ids)} "
-                    "in PHYSICAL_DEVICES.")
-            device_ids = device_ids[start_index:end_index]
-        assert len(device_ids) > self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.tp_rank]  # type: ignore
-        if vllm_config.kv_transfer_config.get_from_extra_config(
-                'use_ascend_direct', True):
-            hostname = self.side_channel_host
-        else:
-            hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
-        self._initialize(hostname=hostname, device_name=None)
+        logger.info("Initializing Mooncake work %s", engine_id)
+        self.engine = global_te.get_transfer_engine(self.side_channel_host,
+                                                    device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()

         # Background thread for sending or receiving KV caches.
@@ -747,19 +742,6 @@ class MooncakeLayerwiseConnectorWorker:
         assert "dp_size" in decode_parallel_config.keys()
         self._decode_dp_size = decode_parallel_config["dp_size"]

-    def _initialize(
-        self,
-        hostname: str,
-        device_name: Optional[str],
-    ) -> None:
-        """Initialize the mooncake instance."""
-        device_name = device_name if device_name is not None else ""
-        ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
-                                           device_name)
-        if ret_value != 0:
-            raise RuntimeError(
-                f"Mooncake initialization failed with ret_value: {ret_value}")
-
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""
@@ -798,6 +780,8 @@ class MooncakeLayerwiseConnectorWorker:
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
+        ptrs = []
+        lengths = []
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
             if self.use_mla:
@@ -805,7 +789,8 @@ class MooncakeLayerwiseConnectorWorker:
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[i % 2]
                     kv_caches_base_addr.append(base_addr)
-                    self._register(base_addr, region_len)
+                    ptrs.append(base_addr)
+                    lengths.append(region_len)
             else:
                 cache_list = [cache_or_caches
                               ] if self.use_mla else cache_or_caches
@@ -813,7 +798,9 @@ class MooncakeLayerwiseConnectorWorker:
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[0]
                     kv_caches_base_addr.append(base_addr)
-                    self._register(base_addr, region_len)
+                    ptrs.append(base_addr)
+                    lengths.append(region_len)
+        global_te.register_buffer(ptrs, lengths)
         self.kv_caches_base_addr = kv_caches_base_addr

         # After KV Caches registered, start the sending or receiving thread.
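The hunk above replaces the per-tensor `engine.register_memory` calls with a single batched `global_te.register_buffer(ptrs, lengths)` after the loop. A self-contained sketch of the collection pattern (addresses and sizes below are stand-ins, not values from the PR):

```python
# Collect every cache's base address and region length, then register once.
ptrs: list[int] = []
lengths: list[int] = []
num_blocks, block_len = 10, 1024
for base_addr in (0x1000, 0x2000):          # stand-ins for cache.data_ptr()
    ptrs.append(base_addr)
    lengths.append(num_blocks * block_len)  # region covered by each cache
# global_te.register_buffer(ptrs, lengths)  # the actual call from the diff
print(list(zip(ptrs, lengths)))
```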
@@ -833,6 +820,7 @@ class MooncakeLayerwiseConnectorWorker:
                 kv_cache_base_addr=self.kv_caches_base_addr,
                 use_mla=self.use_mla,
                 block_len=self.block_len,
+                decode_tp_size=self._decode_tp_size,
                 first_kv_cache=first_kv_cache,
                 callback_func=self.send_done_send_signal)
             self.kv_send_layer_thread.start()
@@ -846,14 +834,6 @@ class MooncakeLayerwiseConnectorWorker:
             self.kv_recv_layer_thread.start()
         ready_event.wait()

-    def _register(self, ptr, length):
-        logger.info(
-            "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
-            "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
-        ret_value = self.engine.register_memory(ptr, length)
-        if ret_value != 0:
-            raise RuntimeError("Mooncake memory registration failed.")
-
     def _access_metaserver(self, url, message):
         success = False
         retry = 0
@@ -969,9 +949,6 @@ class MooncakeLayerwiseConnectorWorker:
         key = None
         value = None
         for req_id, req_meta in connector_metadata.requests.items():
-            logger.debug(
-                f"Add request {req_id} to kv send layer thread. {req_meta=}"
-            )
             if self.pd_head_ratio != 1:
                 key_block_num = len(
                     req_meta.local_block_ids) * key_block_size
@@ -983,6 +960,9 @@ class MooncakeLayerwiseConnectorWorker:
                 key_start_id += key_block_num
                 value_start_id += value_block_num
             req_meta_update = self.update_decoder_info(req_id, req_meta)
+            logger.debug(
+                f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
+            )
             assert self.kv_send_layer_thread is not None
             self.kv_send_layer_thread.send_queue.put(
                 (req_id, req_meta_update, self.current_layer, key, value))
@@ -1011,10 +991,8 @@ class MooncakeLayerwiseConnectorWorker:
     def update_decoder_info(self, req_id, req_meta):
         req_meta_update = copy.deepcopy(req_meta)
-        if self.pd_tp_ratio > 1:
-            req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio
-        else:
-            req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank
+        req_meta_update.remote_port = req_meta_update.remote_port + (
+            self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
         if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
            req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
             try:
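A worked example of the new remote-port mapping above (self-contained; the sizes are illustrative): the offset `tp_rank // pd_tp_ratio` is now wrapped modulo the decode TP size instead of branching on `pd_tp_ratio`, so prefill ranks always land on a valid decoder port.

```python
# New mapping: remote_port + (tp_rank // pd_tp_ratio) % decode_tp_size.
pd_tp_ratio, decode_tp_size, remote_port = 2, 2, 20002
for tp_rank in range(4):  # prefill TP ranks
    port = remote_port + (tp_rank // pd_tp_ratio) % decode_tp_size
    print(f"prefill tp_rank={tp_rank} -> decoder port {port}")
# ranks 0,1 -> 20002; ranks 2,3 -> 20003
```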