[Bugfix][PD] Make multiple Ps and Ds work on a single machine (#2080)
(cherry picked from commit 816375e0c1071d0696dfab1a1ce35674f9f37aa0)
### What this PR does / why we need it?
Suppose that you want to start a prefiller instance with npus `2,3`
only. So you start the instance with `ASCEND_RT_VISIBLE_DEVICES=2,3`.
The current programming will start two workers, whose ranks are `0` and
`1` respectedly. And they will pick the first and second ip addresses of
npus in the ranktable instead of the thirdth and forth ones. But
actually they are using card `2,3` and therefore they can not link with
remote instances when they attempt to transfer the KVCache.
Hence, at most 1 prefiller instance and at most 1 decoder instance can
work on a single machine since they always pick the first npu ip address
in the ranktable currently.
This pull request is proposed to fix the problem. This fix pick ips of
only those devices that are in `ASCEND_RT_VISIBLE_DEVICES` from the
ranktable.
### Does this PR introduce _any_ user-facing change?
If the user use ranktable generated by `gen_ranktable.sh`, they should
not face any change.
### How was this patch tested?
Qwen-0.6B 1P 1D, dp=2, `ASCEND_RT_VISIBLE_DEVICES=2,3` for prefiller and
`ASCEND_RT_VISIBLE_DEVICES=4,5` for decoder.
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
Signed-off-by: CaveNightingale <cavenightingale@foxmail.com>
This commit is contained in:
@@ -2,10 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
import types
|
||||
|
||||
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \
|
||||
LLMDataDistCMgrConnectorMetadata
|
||||
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
|
||||
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
|
||||
|
||||
|
||||
def test_basic_inferface():
|
||||
@@ -40,3 +43,54 @@ def test_basic_inferface():
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id]):
|
||||
assert block_id == block.block_id
|
||||
|
||||
|
||||
def test_read_agent_metadata():
|
||||
rank_table = {
|
||||
"version":
|
||||
"1.2",
|
||||
"server_count":
|
||||
"2",
|
||||
"prefill_device_list": [{
|
||||
"server_id": "192.168.1.1",
|
||||
"device_id": "0",
|
||||
"device_ip": "10.30.0.1",
|
||||
"cluster_id": "0",
|
||||
}, {
|
||||
"server_id": "192.168.1.1",
|
||||
"device_id": "1",
|
||||
"device_ip": "10.30.0.2",
|
||||
"cluster_id": "1",
|
||||
}, {
|
||||
"server_id": "192.168.1.2",
|
||||
"device_id": "0",
|
||||
"device_ip": "10.30.0.3",
|
||||
"cluster_id": "2",
|
||||
}, {
|
||||
"server_id": "192.168.1.2",
|
||||
"device_id": "1",
|
||||
"device_ip": "10.30.0.4",
|
||||
"cluster_id": "3",
|
||||
}]
|
||||
}
|
||||
|
||||
def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
|
||||
old_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
worker = types.SimpleNamespace()
|
||||
worker.local_ip = worker_local_ip
|
||||
worker.tp_rank = worker_tp_rank
|
||||
worker.llm_datadist_role = LLMRole.PROMPT
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
|
||||
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
|
||||
worker, rank_table)
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = old_visible_devices
|
||||
return agent_metadata.device_ip
|
||||
|
||||
assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
|
||||
assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
|
||||
assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
|
||||
assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
|
||||
assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"
|
||||
|
||||
@@ -9,7 +9,7 @@ from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgspec
|
||||
@@ -331,9 +331,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
self.prefill_device_list: list[tuple[int, int]] = []
|
||||
self.decode_device_list: list[tuple[int, int]] = []
|
||||
global_rank_table = self.read_offline_rank_table()
|
||||
self.local_agent_metadata = self.read_agent_metadata(
|
||||
global_rank_table, self.local_ip, self.local_rank_on_node,
|
||||
self.llm_datadist_role)
|
||||
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
|
||||
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
@@ -448,8 +446,20 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
# global_rank_table = json.dumps(global_rank_table)
|
||||
return global_rank_table
|
||||
|
||||
def read_agent_metadata(self, global_rank_table, server_id, device_rank,
|
||||
agent_role):
|
||||
@staticmethod
|
||||
def _get_visible_devices() -> Callable[[str], bool]:
|
||||
"""
|
||||
Return a test function that check if the given device ID is visible.
|
||||
i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
|
||||
"""
|
||||
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
if not visible_devices:
|
||||
return lambda device_id: True
|
||||
visible_device_list = visible_devices.split(",")
|
||||
return lambda device_id: device_id in visible_device_list
|
||||
|
||||
def read_agent_metadata(self, global_rank_table):
|
||||
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
|
||||
devices_type_list = []
|
||||
agent_metadata = None
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
@@ -462,11 +472,12 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
for device_type in devices_type_list:
|
||||
device_list = global_rank_table[device_type]
|
||||
device_list = [
|
||||
d for d in device_list if d.get("server_id") == server_id
|
||||
d for d in device_list if d.get("server_id") == self.local_ip
|
||||
and device_filter(d.get("device_id", ""))
|
||||
]
|
||||
if len(device_list) <= device_rank:
|
||||
if len(device_list) <= self.tp_rank:
|
||||
continue
|
||||
device_info = device_list[device_rank]
|
||||
device_info = device_list[self.tp_rank]
|
||||
super_pod_id_ = device_info.get("super_pod_id", None)
|
||||
server_id_ = device_info["server_id"]
|
||||
device_id_ = device_info["device_id"]
|
||||
@@ -481,7 +492,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
super_device_id=super_device_id_,
|
||||
cluster_id=cluster_id_,
|
||||
)
|
||||
assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table"
|
||||
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
|
||||
return agent_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user