fix mooncake connector adxl hostname usage (#2824)
### What this PR does / why we need it?
This PR is used to adapt the hostname format for Mooncake when using
adxl. When Mooncake uses adxl, it is necessary to set
```USE_ASCEND_DIRECT``` to True in the file
```/Mooncake/mooncake-common/common.cmake``` during compilation. The
mooncake_connector obtains this config by calling
```vllm_config.kv_transfer_config.get_from_extra_config```, determines
whether Mooncake is using adxl, and selects the corresponding hostname
format.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: main
- vLLM main:
d21a36f5f9
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
@@ -961,6 +961,46 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
for p in self.patches:
|
||||
p.stop() # type: ignore
|
||||
|
||||
def test_worker_use_ascend_direct(self):
|
||||
test_case = [True, False]
|
||||
|
||||
for use_ascend_direct in test_case:
|
||||
with self.subTest(use_ascend_direct=use_ascend_direct):
|
||||
config = MagicMock()
|
||||
config.kv_transfer_config = MagicMock()
|
||||
config.kv_transfer_config.get_from_extra_config.side_effect = (
|
||||
lambda k, d: {
|
||||
"prefill": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
},
|
||||
"decode": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
},
|
||||
"use_ascend_direct": use_ascend_direct,
|
||||
}.get(k, d))
|
||||
|
||||
config.parallel_config = MagicMock()
|
||||
config.parallel_config.tensor_parallel_size = 2
|
||||
config.parallel_config.data_parallel_rank_local = 0
|
||||
config.parallel_config.data_parallel_size_local = 1
|
||||
config.kv_transfer_config.kv_port = 8000
|
||||
config.kv_transfer_config.kv_role = 'worker'
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
|
||||
return_value=0):
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_tp_group",
|
||||
return_value=None):
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_ip",
|
||||
return_value="127.0.0.1"):
|
||||
worker = MooncakeConnectorWorker(
|
||||
config, self.engine_id)
|
||||
self.assertIsNotNone(worker)
|
||||
|
||||
def test_register_kv_caches_producer(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
|
||||
@@ -782,10 +782,12 @@ class MooncakeConnectorWorker:
|
||||
assert len(device_ids) > self.tp_rank # type: ignore
|
||||
self.device_id = device_ids[self.tp_rank] # type: ignore
|
||||
|
||||
self._initialize(
|
||||
hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \
|
||||
+ str(self.device_id),
|
||||
device_name=None)
|
||||
if vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
'use_ascend_direct', False):
|
||||
hostname = self.side_channel_host
|
||||
else:
|
||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
||||
self._initialize(hostname=hostname, device_name=None)
|
||||
self.te_rpc_port = self.engine.get_rpc_port()
|
||||
|
||||
# Background thread for sending or receiving KV caches.
|
||||
|
||||
Reference in New Issue
Block a user