[P/D] check kv extra config and del hccl backend (#4547)
### What this PR does / why we need it?
Check the KV connector extra config (`kv_connector_extra_config`) and delete the HCCL backend.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -430,7 +430,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
|
|||||||
"engine_id": "0",
|
"engine_id": "0",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
||||||
"kv_connector_extra_config": {
|
"kv_connector_extra_config": {
|
||||||
"use_ascend_direct": true,
|
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"dp_size": 2,
|
"dp_size": 2,
|
||||||
"tp_size": 8
|
"tp_size": 8
|
||||||
@@ -510,7 +509,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
|
|||||||
"engine_id": "1",
|
"engine_id": "1",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
||||||
"kv_connector_extra_config": {
|
"kv_connector_extra_config": {
|
||||||
"use_ascend_direct": true,
|
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"dp_size": 2,
|
"dp_size": 2,
|
||||||
"tp_size": 8
|
"tp_size": 8
|
||||||
@@ -590,7 +588,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
|
|||||||
"engine_id": "2",
|
"engine_id": "2",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
||||||
"kv_connector_extra_config": {
|
"kv_connector_extra_config": {
|
||||||
"use_ascend_direct": true,
|
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"dp_size": 2,
|
"dp_size": 2,
|
||||||
"tp_size": 8
|
"tp_size": 8
|
||||||
@@ -670,7 +667,6 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
|
|||||||
"engine_id": "3",
|
"engine_id": "3",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
"kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
|
||||||
"kv_connector_extra_config": {
|
"kv_connector_extra_config": {
|
||||||
"use_ascend_direct": true,
|
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"dp_size": 2,
|
"dp_size": 2,
|
||||||
"tp_size": 8
|
"tp_size": 8
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
|
|||||||
"metadata_server": "P2PHANDSHAKE",
|
"metadata_server": "P2PHANDSHAKE",
|
||||||
"protocol": "ascend",
|
"protocol": "ascend",
|
||||||
"device_name": "",
|
"device_name": "",
|
||||||
"use_ascend_direct": true,
|
|
||||||
"alloc_in_same_node": true,
|
"alloc_in_same_node": true,
|
||||||
"master_server_address": "xx.xx.xx.xx:50088",
|
"master_server_address": "xx.xx.xx.xx:50088",
|
||||||
"global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
|
"global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
|
||||||
@@ -52,7 +51,6 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
|
|||||||
**metadata_server**: Configured as **P2PHANDSHAKE**.
|
**metadata_server**: Configured as **P2PHANDSHAKE**.
|
||||||
**protocol**: Configured for Ascend to use Mooncake's HCCL communication.
|
**protocol**: Configured for Ascend to use Mooncake's HCCL communication.
|
||||||
**device_name**: ""
|
**device_name**: ""
|
||||||
**use_ascend_direct**: Indicator for using ADXL engine.
|
|
||||||
**alloc_in_same_node**: Indicator for preferring local buffer allocation strategy.
|
**alloc_in_same_node**: Indicator for preferring local buffer allocation strategy.
|
||||||
**master_server_address**: Configured with the IP and port of the master service.
|
**master_server_address**: Configured with the IP and port of the master service.
|
||||||
**global_segment_size**: Expands the kvcache size registered by the PD node to the master.
|
**global_segment_size**: Expands the kvcache size registered by the PD node to the master.
|
||||||
@@ -133,7 +131,7 @@ python3 -m vllm.entrypoints.openai.api_server \
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}' > p.log 2>&1
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
`decode` Node:
|
`decode` Node:
|
||||||
@@ -177,7 +175,6 @@ python3 -m vllm.entrypoints.openai.api_server \
|
|||||||
"kv_role": "kv_consumer",
|
"kv_role": "kv_consumer",
|
||||||
"kv_port": "20002",
|
"kv_port": "20002",
|
||||||
"kv_connector_extra_config": {
|
"kv_connector_extra_config": {
|
||||||
"use_ascend_direct": true,
|
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"dp_size": 1,
|
"dp_size": 1,
|
||||||
"tp_size": 1
|
"tp_size": 1
|
||||||
@@ -196,7 +193,7 @@ python3 -m vllm.entrypoints.openai.api_server \
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}' > d.log 2>&1
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 2. Start proxy_server.
|
#### 2. Start proxy_server.
|
||||||
|
|||||||
@@ -639,10 +639,15 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
self.p1 = patch(
|
self.p1 = patch(
|
||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config',
|
||||||
new=MagicMock(return_value=None))
|
new=MagicMock())
|
||||||
|
self.p2 = patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
new=MagicMock(return_value=MagicMock()))
|
||||||
self.p1.start()
|
self.p1.start()
|
||||||
|
self.p2.start()
|
||||||
self.addCleanup(self.p1.stop)
|
self.addCleanup(self.p1.stop)
|
||||||
|
self.addCleanup(self.p2.stop)
|
||||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||||
|
|
||||||
def test_get_num_new_matched_tokens(self):
|
def test_get_num_new_matched_tokens(self):
|
||||||
@@ -716,7 +721,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
|||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||||
self.assertIsNotNone(connector.connector_scheduler)
|
self.assertIsNotNone(connector.connector_scheduler)
|
||||||
self.assertIsNone(connector.connector_worker)
|
self.assertIsNone(connector.connector_worker)
|
||||||
@@ -726,7 +733,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
|||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
connector.get_num_new_matched_tokens(request, 0)
|
connector.get_num_new_matched_tokens(request, 0)
|
||||||
@@ -756,7 +765,9 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
def test_scheduler_initialization(self):
|
def test_scheduler_initialization(self):
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(self.config,
|
connector = MooncakeConnector(self.config,
|
||||||
KVConnectorRole.SCHEDULER)
|
KVConnectorRole.SCHEDULER)
|
||||||
self.assertIsNotNone(connector.connector_scheduler)
|
self.assertIsNotNone(connector.connector_scheduler)
|
||||||
@@ -766,7 +777,9 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
def test_get_num_new_matched_tokens(self, mock_method):
|
def test_get_num_new_matched_tokens(self, mock_method):
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(self.config,
|
connector = MooncakeConnector(self.config,
|
||||||
KVConnectorRole.SCHEDULER)
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
@@ -777,7 +790,9 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
def test_update_state_after_alloc(self, mock_method):
|
def test_update_state_after_alloc(self, mock_method):
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(self.config,
|
connector = MooncakeConnector(self.config,
|
||||||
KVConnectorRole.SCHEDULER)
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
@@ -789,7 +804,9 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
def test_build_connector_meta(self, mock_method):
|
def test_build_connector_meta(self, mock_method):
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(self.config,
|
connector = MooncakeConnector(self.config,
|
||||||
KVConnectorRole.SCHEDULER)
|
KVConnectorRole.SCHEDULER)
|
||||||
scheduler_output = MockSchedulerOutput()
|
scheduler_output = MockSchedulerOutput()
|
||||||
@@ -800,7 +817,9 @@ class TestMooncakeConnector(unittest.TestCase):
|
|||||||
def test_request_finished(self, mock_method):
|
def test_request_finished(self, mock_method):
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
connector = MooncakeConnector(self.config,
|
connector = MooncakeConnector(self.config,
|
||||||
KVConnectorRole.SCHEDULER)
|
KVConnectorRole.SCHEDULER)
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
@@ -814,7 +833,9 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
|
|||||||
self.config = MockVllmConfig()
|
self.config = MockVllmConfig()
|
||||||
with patch(
|
with patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||||
):
|
), patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()):
|
||||||
self.scheduler = MooncakeConnectorScheduler(
|
self.scheduler = MooncakeConnectorScheduler(
|
||||||
self.config, "test_engine")
|
self.config, "test_engine")
|
||||||
|
|
||||||
@@ -1037,9 +1058,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
self.mock_pcp_group.device_group = MagicMock()
|
self.mock_pcp_group.device_group = MagicMock()
|
||||||
|
|
||||||
self.patches = [
|
self.patches = [
|
||||||
patch(
|
|
||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
|
|
||||||
'10,11'),
|
|
||||||
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
||||||
patch('torch.Tensor.element_size', return_value=4),
|
patch('torch.Tensor.element_size', return_value=4),
|
||||||
patch('torch.Tensor.data_ptr', return_value=0x1000),
|
patch('torch.Tensor.data_ptr', return_value=0x1000),
|
||||||
@@ -1056,8 +1074,11 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
|
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
|
||||||
mock_string_to_int64_hash),
|
mock_string_to_int64_hash),
|
||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine',
|
'vllm_ascend.distributed.mooncake_connector.global_te.get_transfer_engine',
|
||||||
return_value=self.mock_transfer_engine),
|
return_value=self.mock_transfer_engine),
|
||||||
|
patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.global_te.register_buffer',
|
||||||
|
return_value=None),
|
||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
|
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
|
||||||
MagicMock()),
|
MagicMock()),
|
||||||
@@ -1073,10 +1094,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
patch('vllm.distributed.parallel_state._DCP',
|
patch('vllm.distributed.parallel_state._DCP',
|
||||||
return_value=self.mock_dcp),
|
return_value=self.mock_dcp),
|
||||||
patch(
|
patch(
|
||||||
'vllm.distributed.get_decode_context_model_parallel_world_size',
|
'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
|
||||||
return_value=1),
|
return_value=1),
|
||||||
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
|
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
|
||||||
return_value=self.mock_pcp_group),
|
return_value=self.mock_pcp_group),
|
||||||
|
patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
|
||||||
|
return_value=MagicMock()),
|
||||||
]
|
]
|
||||||
|
|
||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
@@ -1090,46 +1114,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
p.stop() # type: ignore
|
p.stop() # type: ignore
|
||||||
|
|
||||||
def test_worker_use_ascend_direct(self):
|
|
||||||
test_case = [True, False]
|
|
||||||
|
|
||||||
for use_ascend_direct in test_case:
|
|
||||||
with self.subTest(use_ascend_direct=use_ascend_direct):
|
|
||||||
config = MagicMock()
|
|
||||||
config.kv_transfer_config = MagicMock()
|
|
||||||
config.kv_transfer_config.get_from_extra_config.side_effect = (
|
|
||||||
lambda k, d: {
|
|
||||||
"prefill": {
|
|
||||||
"tp_size": 2,
|
|
||||||
"dp_size": 1
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"tp_size": 2,
|
|
||||||
"dp_size": 1
|
|
||||||
},
|
|
||||||
"use_ascend_direct": use_ascend_direct,
|
|
||||||
}.get(k, d))
|
|
||||||
|
|
||||||
config.parallel_config = MagicMock()
|
|
||||||
config.parallel_config.tensor_parallel_size = 2
|
|
||||||
config.parallel_config.data_parallel_rank = 0
|
|
||||||
config.parallel_config.data_parallel_size_local = 1
|
|
||||||
config.kv_transfer_config.kv_port = 8000
|
|
||||||
config.kv_transfer_config.kv_role = 'worker'
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
|
|
||||||
return_value=0):
|
|
||||||
with patch(
|
|
||||||
"vllm_ascend.distributed.mooncake_connector.get_tp_group",
|
|
||||||
return_value=None):
|
|
||||||
with patch(
|
|
||||||
"vllm_ascend.distributed.mooncake_connector.get_ip",
|
|
||||||
return_value="127.0.0.1"):
|
|
||||||
worker = MooncakeConnectorWorker(
|
|
||||||
config, self.engine_id)
|
|
||||||
self.assertIsNotNone(worker)
|
|
||||||
|
|
||||||
def test_register_kv_caches_producer(self):
|
def test_register_kv_caches_producer(self):
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||||
worker.register_kv_caches(self.kv_caches)
|
worker.register_kv_caches(self.kv_caches)
|
||||||
@@ -1160,7 +1144,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
# Test with physical devices set
|
# Test with physical devices set
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||||
# Default tp_rank is 0, so device_id should be 10
|
# Constructing the worker should obtain a transfer engine
|
||||||
self.assertEqual(worker.device_id, 10)
|
self.assertIsNotNone(worker.engine)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
6000], # 2 * total_layers
|
6000], # 2 * total_layers
|
||||||
use_mla=True,
|
use_mla=True,
|
||||||
block_len=[1024, 2048],
|
block_len=[1024, 2048],
|
||||||
|
decode_tp_size=1,
|
||||||
first_kv_cache=self.first_kv_cache,
|
first_kv_cache=self.first_kv_cache,
|
||||||
callback_func=MagicMock())
|
callback_func=MagicMock())
|
||||||
|
|
||||||
@@ -97,6 +98,7 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
kv_cache_base_addr=[1111, 2222, 3333, 4444],
|
kv_cache_base_addr=[1111, 2222, 3333, 4444],
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
block_len=[64],
|
block_len=[64],
|
||||||
|
decode_tp_size=1,
|
||||||
first_kv_cache=self.first_kv_cache,
|
first_kv_cache=self.first_kv_cache,
|
||||||
callback_func=MagicMock())
|
callback_func=MagicMock())
|
||||||
|
|
||||||
@@ -155,6 +157,7 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
kv_cache_base_addr=[1000, 2000],
|
kv_cache_base_addr=[1000, 2000],
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
block_len=[1024],
|
block_len=[1024],
|
||||||
|
decode_tp_size=1,
|
||||||
first_kv_cache=self.first_kv_cache,
|
first_kv_cache=self.first_kv_cache,
|
||||||
callback_func=MagicMock())
|
callback_func=MagicMock())
|
||||||
req_meta = self.req_meta_base
|
req_meta = self.req_meta_base
|
||||||
@@ -397,7 +400,6 @@ class MockVllmConfig:
|
|||||||
"tp_size": 2,
|
"tp_size": 2,
|
||||||
"dp_size": 1
|
"dp_size": 1
|
||||||
},
|
},
|
||||||
"use_ascend_direct": True,
|
|
||||||
}.get(k, d)
|
}.get(k, d)
|
||||||
|
|
||||||
|
|
||||||
@@ -806,9 +808,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
|
|||||||
self.mock_transfer_engine.register_memory.return_value = 0
|
self.mock_transfer_engine.register_memory.return_value = 0
|
||||||
|
|
||||||
self.patches = [
|
self.patches = [
|
||||||
patch(
|
|
||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
|
|
||||||
'10,11'),
|
|
||||||
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
||||||
patch('torch.Tensor.element_size', return_value=4),
|
patch('torch.Tensor.element_size', return_value=4),
|
||||||
patch('torch.Tensor.data_ptr', return_value=0x1000),
|
patch('torch.Tensor.data_ptr', return_value=0x1000),
|
||||||
@@ -827,8 +826,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
|
|||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash',
|
'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash',
|
||||||
side_effect=lambda s: hash(s)),
|
side_effect=lambda s: hash(s)),
|
||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.TransferEngine',
|
'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.get_transfer_engine',
|
||||||
return_value=self.mock_transfer_engine),
|
return_value=self.mock_transfer_engine),
|
||||||
|
patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.register_buffer',
|
||||||
|
return_value=None),
|
||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread',
|
'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread',
|
||||||
MagicMock()),
|
MagicMock()),
|
||||||
@@ -859,26 +861,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
|
|||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
p.stop() # type: ignore
|
p.stop() # type: ignore
|
||||||
|
|
||||||
def test_worker_use_ascend_direct(self):
|
|
||||||
for use_ascend_direct in (True, False):
|
|
||||||
with self.subTest(use_ascend_direct=use_ascend_direct):
|
|
||||||
config = MockVllmConfig()
|
|
||||||
config.kv_transfer_config.get_from_extra_config.side_effect = (
|
|
||||||
lambda k, d: {
|
|
||||||
"prefill": {
|
|
||||||
"tp_size": 2,
|
|
||||||
"dp_size": 1
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"tp_size": 2,
|
|
||||||
"dp_size": 1
|
|
||||||
},
|
|
||||||
"use_ascend_direct": use_ascend_direct,
|
|
||||||
}.get(k, d))
|
|
||||||
worker = MooncakeLayerwiseConnectorWorker(
|
|
||||||
config, self.engine_id)
|
|
||||||
self.assertIsNotNone(worker)
|
|
||||||
|
|
||||||
def test_register_kv_caches_producer(self):
|
def test_register_kv_caches_producer(self):
|
||||||
|
|
||||||
self.vllm_config.kv_transfer_config.is_kv_producer = True
|
self.vllm_config.kv_transfer_config.is_kv_producer = True
|
||||||
@@ -915,7 +897,7 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
|
|||||||
def test_device_id_selection_with_physical_devices(self):
|
def test_device_id_selection_with_physical_devices(self):
|
||||||
worker = MooncakeLayerwiseConnectorWorker(self.vllm_config,
|
worker = MooncakeLayerwiseConnectorWorker(self.vllm_config,
|
||||||
self.engine_id)
|
self.engine_id)
|
||||||
self.assertEqual(worker.device_id, 10)
|
self.assertIsNotNone(worker.engine)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -28,6 +28,37 @@ def _check_torchair_supported(model_type: str):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_kv_extra_config(vllm_config):
    """Validate the KV-transfer extra config against the parallel config.

    For a KV producer the "prefill" entry of the extra config is checked;
    for a KV consumer the "decode" entry is checked. Within an entry,
    "tp_size" must equal ``parallel_config.tensor_parallel_size`` and
    "dp_size" must equal ``parallel_config.data_parallel_size``. Keys that
    are absent from the entry are not checked.

    Args:
        vllm_config: Config object exposing ``parallel_config`` and a
            ``kv_transfer_config`` (which is dereferenced here, so it must
            not be None).

    Raises:
        ValueError: If a configured size differs from the corresponding
            value in ``parallel_config``.
    """
    # (extra-config key, parallel_config attribute, label used in the error).
    # Tensor parallel is listed first so it is reported before data parallel.
    size_checks = (
        ("tp_size", "tensor_parallel_size", "tensor parallel size"),
        ("dp_size", "data_parallel_size", "data parallel size"),
    )

    def _check(name: str, config: dict):
        for key, attr, label in size_checks:
            if key not in config:
                continue
            configured = config[key]
            # Read the attribute only when the key is present.
            expected = getattr(vllm_config.parallel_config, attr)
            if configured != expected:
                raise ValueError(
                    f"KV transfer '{name}' config has a conflicting {label}. "
                    f"Expected {expected}, but got {configured}.")

    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config.is_kv_producer:
        _check("prefill",
               kv_transfer_config.get_from_extra_config("prefill", {}))
    if kv_transfer_config.is_kv_consumer:
        _check("decode",
               kv_transfer_config.get_from_extra_config("decode", {}))
|
||||||
|
|
||||||
|
|
||||||
class AscendConfig:
|
class AscendConfig:
|
||||||
"""
|
"""
|
||||||
Configuration Object for additional_config from vllm.configs.
|
Configuration Object for additional_config from vllm.configs.
|
||||||
@@ -112,6 +143,10 @@ class AscendConfig:
|
|||||||
)
|
)
|
||||||
self.enable_cpu_binding = additional_config.get(
|
self.enable_cpu_binding = additional_config.get(
|
||||||
"enable_cpu_binding", False)
|
"enable_cpu_binding", False)
|
||||||
|
|
||||||
|
if vllm_config.kv_transfer_config is not None:
|
||||||
|
check_kv_extra_config(vllm_config)
|
||||||
|
|
||||||
self.pd_tp_ratio = 1
|
self.pd_tp_ratio = 1
|
||||||
self.pd_head_ratio = 1
|
self.pd_head_ratio = 1
|
||||||
self.num_head_replica = 1
|
self.num_head_replica = 1
|
||||||
|
|||||||
@@ -83,7 +83,6 @@ class MooncakeStoreConfig:
|
|||||||
protocol: str
|
protocol: str
|
||||||
device_name: str
|
device_name: str
|
||||||
master_server_address: str
|
master_server_address: str
|
||||||
use_ascend_direct: bool
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||||
@@ -99,8 +98,7 @@ class MooncakeStoreConfig:
|
|||||||
DEFAULT_LOCAL_BUFFER_SIZE)),
|
DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||||
protocol=config.get("protocol", "tcp"),
|
protocol=config.get("protocol", "tcp"),
|
||||||
device_name=config.get("device_name", ""),
|
device_name=config.get("device_name", ""),
|
||||||
master_server_address=config.get("master_server_address"),
|
master_server_address=config.get("master_server_address"))
|
||||||
use_ascend_direct=config.get("use_ascend_direct", False))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_from_env() -> "MooncakeStoreConfig":
|
def load_from_env() -> "MooncakeStoreConfig":
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||||
@@ -653,6 +652,7 @@ class MooncakeConnector(KVConnectorBase_V1):
|
|||||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||||
|
self._connector_metadata = MooncakeConnectorMetadata()
|
||||||
|
|
||||||
if role == KVConnectorRole.SCHEDULER:
|
if role == KVConnectorRole.SCHEDULER:
|
||||||
self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
|
self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
|
||||||
@@ -744,9 +744,6 @@ class MooncakeConnectorScheduler:
|
|||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
|
|
||||||
vllm_config.parallel_config.data_parallel_size * \
|
|
||||||
self.pcp_size
|
|
||||||
|
|
||||||
# Handshake base port
|
# Handshake base port
|
||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
@@ -905,8 +902,6 @@ class MooncakeConnectorWorker:
|
|||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size_local
|
|
||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.pcp_size = get_pcp_group().world_size
|
self.pcp_size = get_pcp_group().world_size
|
||||||
@@ -916,7 +911,6 @@ class MooncakeConnectorWorker:
|
|||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
) if self.dcp_size > 1 else 0
|
) if self.dcp_size > 1 else 0
|
||||||
|
|
||||||
self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
|
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
|
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
|
||||||
|
|
||||||
@@ -927,38 +921,9 @@ class MooncakeConnectorWorker:
|
|||||||
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
|
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
|
||||||
self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
|
self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
|
||||||
self.sockets: dict = {}
|
self.sockets: dict = {}
|
||||||
|
|
||||||
# get tp device id
|
|
||||||
# TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
|
|
||||||
# introducing some changes
|
|
||||||
device_ids_str = envs_ascend.PHYSICAL_DEVICES
|
|
||||||
if device_ids_str is None:
|
|
||||||
device_ids = list(
|
|
||||||
range(self.dp_rank * self.tp_size * self.pcp_size,
|
|
||||||
(self.dp_rank + 1) * self.tp_size * self.pcp_size))
|
|
||||||
else:
|
|
||||||
device_ids = list(map(int, device_ids_str.split(',')))
|
|
||||||
start_index = self.dp_rank * self.tp_size * self.pcp_size
|
|
||||||
end_index = start_index + self.tp_size * self.pcp_size
|
|
||||||
if len(device_ids) < end_index:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough physical devices available for DP rank {self.dp_rank}. "
|
|
||||||
f"Expected at least {end_index} devices, but found {len(device_ids)} "
|
|
||||||
"in PHYSICAL_DEVICES.")
|
|
||||||
device_ids = device_ids[start_index:end_index]
|
|
||||||
assert len(
|
|
||||||
device_ids
|
|
||||||
) > self.pcp_rank * self.tp_size + self.tp_rank # type: ignore
|
|
||||||
self.device_id = device_ids[self.pcp_rank * self.tp_size +
|
|
||||||
self.tp_rank] # type: ignore
|
|
||||||
|
|
||||||
if vllm_config.kv_transfer_config.get_from_extra_config(
|
|
||||||
'use_ascend_direct', True):
|
|
||||||
hostname = self.side_channel_host
|
|
||||||
else:
|
|
||||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
|
||||||
logger.info("Initializing Mooncake work %s", engine_id)
|
logger.info("Initializing Mooncake work %s", engine_id)
|
||||||
self.engine = global_te.get_transfer_engine(hostname, device_name=None)
|
self.engine = global_te.get_transfer_engine(self.side_channel_host,
|
||||||
|
device_name=None)
|
||||||
self.te_rpc_port = self.engine.get_rpc_port()
|
self.te_rpc_port = self.engine.get_rpc_port()
|
||||||
|
|
||||||
# Background thread for sending or receiving KV caches.
|
# Background thread for sending or receiving KV caches.
|
||||||
@@ -998,19 +963,6 @@ class MooncakeConnectorWorker:
|
|||||||
assert "dp_size" in decode_parallel_config.keys()
|
assert "dp_size" in decode_parallel_config.keys()
|
||||||
self._decode_dp_size = decode_parallel_config["dp_size"]
|
self._decode_dp_size = decode_parallel_config["dp_size"]
|
||||||
|
|
||||||
def _initialize(
|
|
||||||
self,
|
|
||||||
hostname: str,
|
|
||||||
device_name: Optional[str],
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the mooncake instance."""
|
|
||||||
device_name = device_name if device_name is not None else ""
|
|
||||||
ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
|
|
||||||
device_name)
|
|
||||||
if ret_value != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Mooncake initialization failed with ret_value: {ret_value}")
|
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
"""Register the KV Cache data."""
|
"""Register the KV Cache data."""
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||||
from vllm_ascend.distributed.utils import (align_memory,
|
from vllm_ascend.distributed.utils import (align_memory,
|
||||||
get_transfer_timeout_value,
|
get_transfer_timeout_value,
|
||||||
kv_alltoall_and_rearrange)
|
kv_alltoall_and_rearrange)
|
||||||
@@ -100,6 +100,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
kv_cache_base_addr: list[int],
|
kv_cache_base_addr: list[int],
|
||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
block_len: list[int],
|
block_len: list[int],
|
||||||
|
decode_tp_size: int,
|
||||||
first_kv_cache: torch.Tensor,
|
first_kv_cache: torch.Tensor,
|
||||||
callback_func: Callable[..., None] = lambda x: None):
|
callback_func: Callable[..., None] = lambda x: None):
|
||||||
super().__init__(daemon=True, name="KVCacheSendingLayerThread")
|
super().__init__(daemon=True, name="KVCacheSendingLayerThread")
|
||||||
@@ -111,6 +112,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
self.total_layers = total_layers
|
self.total_layers = total_layers
|
||||||
self.use_mla = use_mla
|
self.use_mla = use_mla
|
||||||
self.block_len = block_len
|
self.block_len = block_len
|
||||||
|
self._decode_tp_size = decode_tp_size
|
||||||
self.model_stream = torch_npu.npu.current_stream()
|
self.model_stream = torch_npu.npu.current_stream()
|
||||||
self.current_layer = -1
|
self.current_layer = -1
|
||||||
|
|
||||||
@@ -172,9 +174,20 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
|
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
|
||||||
# send kv layer to remote
|
# send kv layer to remote
|
||||||
if len(req_meta.local_block_ids) == 0:
|
if len(req_meta.local_block_ids) == 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
# not need to send kv cache
|
# not need to send kv cache
|
||||||
if self.tp_rank % self.num_head_replica != 0:
|
if self.tp_rank % self.num_head_replica != 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if self.use_mla and self.tp_rank >= self._decode_tp_size:
|
||||||
|
logger.debug(
|
||||||
|
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
remote_host = req_meta.remote_host
|
remote_host = req_meta.remote_host
|
||||||
@@ -484,8 +497,6 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
||||||
|
|
||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
|
|
||||||
vllm_config.parallel_config.data_parallel_size
|
|
||||||
|
|
||||||
# Handshake base port
|
# Handshake base port
|
||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
@@ -550,6 +561,9 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
local_block_ids = (blocks.get_unhashed_block_ids()
|
local_block_ids = (blocks.get_unhashed_block_ids()
|
||||||
if num_external_tokens > 0 else [])
|
if num_external_tokens > 0 else [])
|
||||||
# Get unhashed blocks to pull from remote.
|
# Get unhashed blocks to pull from remote.
|
||||||
|
logger.debug(
|
||||||
|
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue"
|
||||||
|
)
|
||||||
self._reqs_need_recv[request.request_id] = (
|
self._reqs_need_recv[request.request_id] = (
|
||||||
request,
|
request,
|
||||||
[], #request._all_token_ids,
|
[], #request._all_token_ids,
|
||||||
@@ -560,6 +574,9 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
# Layerwise prefiller add request need send
|
# Layerwise prefiller add request need send
|
||||||
if params is not None and params.get("do_remote_decode"):
|
if params is not None and params.get("do_remote_decode"):
|
||||||
local_block_ids = (blocks.get_block_ids()[0])
|
local_block_ids = (blocks.get_block_ids()[0])
|
||||||
|
logger.debug(
|
||||||
|
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
|
||||||
|
)
|
||||||
self._reqs_need_send_layerwise[request.request_id] = (len(
|
self._reqs_need_send_layerwise[request.request_id] = (len(
|
||||||
request.all_token_ids), local_block_ids, request)
|
request.all_token_ids), local_block_ids, request)
|
||||||
|
|
||||||
@@ -603,12 +620,19 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
req_id]
|
req_id]
|
||||||
current_tokens = computed_tokens.get(req_id,
|
current_tokens = computed_tokens.get(req_id,
|
||||||
0) + scheduled_tokens
|
0) + scheduled_tokens
|
||||||
if current_tokens == total_tokens:
|
if current_tokens >= total_tokens:
|
||||||
|
logger.debug(
|
||||||
|
f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
|
||||||
|
)
|
||||||
meta.add_new_req(request_id=req_id,
|
meta.add_new_req(request_id=req_id,
|
||||||
local_block_ids=block_ids,
|
local_block_ids=block_ids,
|
||||||
kv_transfer_params=req.kv_transfer_params,
|
kv_transfer_params=req.kv_transfer_params,
|
||||||
token_ids=[])
|
token_ids=[])
|
||||||
self._reqs_need_send_layerwise.pop(req_id)
|
self._reqs_need_send_layerwise.pop(req_id)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
|
||||||
|
)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(
|
||||||
@@ -639,7 +663,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
if TransferEngine is None:
|
if TransferEngine is None:
|
||||||
raise RuntimeError("mooncake is not available")
|
raise RuntimeError("mooncake is not available")
|
||||||
logger.info("Initializing Mooncake work %s", engine_id)
|
logger.info("Initializing Mooncake work %s", engine_id)
|
||||||
self.engine = TransferEngine()
|
|
||||||
|
|
||||||
# Metadata.
|
# Metadata.
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
@@ -648,11 +671,8 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size_local
|
|
||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.max_device_id = self.tp_size * self.dp_size
|
|
||||||
self.total_layers = vllm_config.model_config.get_num_layers(
|
self.total_layers = vllm_config.model_config.get_num_layers(
|
||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
|
|
||||||
@@ -668,34 +688,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
vllm_config.parallel_config.tensor_parallel_size)
|
vllm_config.parallel_config.tensor_parallel_size)
|
||||||
self.handshake_port = self.side_channel_port + self.tp_rank
|
self.handshake_port = self.side_channel_port + self.tp_rank
|
||||||
self.sockets: dict = {}
|
self.sockets: dict = {}
|
||||||
|
logger.info("Initializing Mooncake work %s", engine_id)
|
||||||
# get tp device id
|
self.engine = global_te.get_transfer_engine(self.side_channel_host,
|
||||||
# TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
|
device_name=None)
|
||||||
# introducing some changes
|
|
||||||
device_ids_str = envs_ascend.PHYSICAL_DEVICES
|
|
||||||
if device_ids_str is None:
|
|
||||||
device_ids = list(
|
|
||||||
range(self.dp_rank * self.tp_size,
|
|
||||||
(self.dp_rank + 1) * self.tp_size))
|
|
||||||
else:
|
|
||||||
device_ids = list(map(int, device_ids_str.split(',')))
|
|
||||||
start_index = self.dp_rank * self.tp_size
|
|
||||||
end_index = start_index + self.tp_size
|
|
||||||
if len(device_ids) < end_index:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough physical devices available for DP rank {self.dp_rank}. "
|
|
||||||
f"Expected at least {end_index} devices, but found {len(device_ids)} "
|
|
||||||
"in PHYSICAL_DEVICES.")
|
|
||||||
device_ids = device_ids[start_index:end_index]
|
|
||||||
assert len(device_ids) > self.tp_rank # type: ignore
|
|
||||||
self.device_id = device_ids[self.tp_rank] # type: ignore
|
|
||||||
|
|
||||||
if vllm_config.kv_transfer_config.get_from_extra_config(
|
|
||||||
'use_ascend_direct', True):
|
|
||||||
hostname = self.side_channel_host
|
|
||||||
else:
|
|
||||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
|
||||||
self._initialize(hostname=hostname, device_name=None)
|
|
||||||
self.te_rpc_port = self.engine.get_rpc_port()
|
self.te_rpc_port = self.engine.get_rpc_port()
|
||||||
|
|
||||||
# Background thread for sending or receiving KV caches.
|
# Background thread for sending or receiving KV caches.
|
||||||
@@ -747,19 +742,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
assert "dp_size" in decode_parallel_config.keys()
|
assert "dp_size" in decode_parallel_config.keys()
|
||||||
self._decode_dp_size = decode_parallel_config["dp_size"]
|
self._decode_dp_size = decode_parallel_config["dp_size"]
|
||||||
|
|
||||||
def _initialize(
|
|
||||||
self,
|
|
||||||
hostname: str,
|
|
||||||
device_name: Optional[str],
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the mooncake instance."""
|
|
||||||
device_name = device_name if device_name is not None else ""
|
|
||||||
ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
|
|
||||||
device_name)
|
|
||||||
if ret_value != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Mooncake initialization failed with ret_value: {ret_value}")
|
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
"""Register the KV Cache data."""
|
"""Register the KV Cache data."""
|
||||||
|
|
||||||
@@ -798,6 +780,8 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
|
|
||||||
self.kv_caches = kv_caches
|
self.kv_caches = kv_caches
|
||||||
kv_caches_base_addr = []
|
kv_caches_base_addr = []
|
||||||
|
ptrs = []
|
||||||
|
lengths = []
|
||||||
for cache_or_caches in kv_caches.values():
|
for cache_or_caches in kv_caches.values():
|
||||||
# Normalize to always be a list of caches
|
# Normalize to always be a list of caches
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
@@ -805,7 +789,8 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
region_len = self.num_blocks * self.block_len[i % 2]
|
region_len = self.num_blocks * self.block_len[i % 2]
|
||||||
kv_caches_base_addr.append(base_addr)
|
kv_caches_base_addr.append(base_addr)
|
||||||
self._register(base_addr, region_len)
|
ptrs.append(base_addr)
|
||||||
|
lengths.append(region_len)
|
||||||
else:
|
else:
|
||||||
cache_list = [cache_or_caches
|
cache_list = [cache_or_caches
|
||||||
] if self.use_mla else cache_or_caches
|
] if self.use_mla else cache_or_caches
|
||||||
@@ -813,7 +798,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
region_len = self.num_blocks * self.block_len[0]
|
region_len = self.num_blocks * self.block_len[0]
|
||||||
kv_caches_base_addr.append(base_addr)
|
kv_caches_base_addr.append(base_addr)
|
||||||
self._register(base_addr, region_len)
|
ptrs.append(base_addr)
|
||||||
|
lengths.append(region_len)
|
||||||
|
global_te.register_buffer(ptrs, lengths)
|
||||||
self.kv_caches_base_addr = kv_caches_base_addr
|
self.kv_caches_base_addr = kv_caches_base_addr
|
||||||
|
|
||||||
# After KV Caches registered, start the sending or receiving thread.
|
# After KV Caches registered, start the sending or receiving thread.
|
||||||
@@ -833,6 +820,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
kv_cache_base_addr=self.kv_caches_base_addr,
|
kv_cache_base_addr=self.kv_caches_base_addr,
|
||||||
use_mla=self.use_mla,
|
use_mla=self.use_mla,
|
||||||
block_len=self.block_len,
|
block_len=self.block_len,
|
||||||
|
decode_tp_size=self._decode_tp_size,
|
||||||
first_kv_cache=first_kv_cache,
|
first_kv_cache=first_kv_cache,
|
||||||
callback_func=self.send_done_send_signal)
|
callback_func=self.send_done_send_signal)
|
||||||
self.kv_send_layer_thread.start()
|
self.kv_send_layer_thread.start()
|
||||||
@@ -846,14 +834,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.kv_recv_layer_thread.start()
|
self.kv_recv_layer_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
|
|
||||||
def _register(self, ptr, length):
|
|
||||||
logger.info(
|
|
||||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
|
||||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
|
||||||
ret_value = self.engine.register_memory(ptr, length)
|
|
||||||
if ret_value != 0:
|
|
||||||
raise RuntimeError("Mooncake memory registration failed.")
|
|
||||||
|
|
||||||
def _access_metaserver(self, url, message):
|
def _access_metaserver(self, url, message):
|
||||||
success = False
|
success = False
|
||||||
retry = 0
|
retry = 0
|
||||||
@@ -969,9 +949,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
key = None
|
key = None
|
||||||
value = None
|
value = None
|
||||||
for req_id, req_meta in connector_metadata.requests.items():
|
for req_id, req_meta in connector_metadata.requests.items():
|
||||||
logger.debug(
|
|
||||||
f"Add request {req_id} to kv send layer thread. {req_meta=}"
|
|
||||||
)
|
|
||||||
if self.pd_head_ratio != 1:
|
if self.pd_head_ratio != 1:
|
||||||
key_block_num = len(
|
key_block_num = len(
|
||||||
req_meta.local_block_ids) * key_block_size
|
req_meta.local_block_ids) * key_block_size
|
||||||
@@ -983,6 +960,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
key_start_id += key_block_num
|
key_start_id += key_block_num
|
||||||
value_start_id += value_block_num
|
value_start_id += value_block_num
|
||||||
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
||||||
|
logger.debug(
|
||||||
|
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||||
|
)
|
||||||
assert self.kv_send_layer_thread is not None
|
assert self.kv_send_layer_thread is not None
|
||||||
self.kv_send_layer_thread.send_queue.put(
|
self.kv_send_layer_thread.send_queue.put(
|
||||||
(req_id, req_meta_update, self.current_layer, key, value))
|
(req_id, req_meta_update, self.current_layer, key, value))
|
||||||
@@ -1011,10 +991,8 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
|
|
||||||
def update_decoder_info(self, req_id, req_meta):
|
def update_decoder_info(self, req_id, req_meta):
|
||||||
req_meta_update = copy.deepcopy(req_meta)
|
req_meta_update = copy.deepcopy(req_meta)
|
||||||
if self.pd_tp_ratio > 1:
|
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||||
req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio
|
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
|
||||||
else:
|
|
||||||
req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank
|
|
||||||
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
|
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
|
||||||
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
|
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user