diff --git a/tests/ut/kv_connector/test_llmdatadist_connector.py b/tests/ut/kv_connector/test_llmdatadist_connector.py
index 94650f4..b70482f 100644
--- a/tests/ut/kv_connector/test_llmdatadist_connector.py
+++ b/tests/ut/kv_connector/test_llmdatadist_connector.py
@@ -2,10 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 
+import os
+import types
+
 from tests.ut.kv_connector.utils import (create_request, create_scheduler,
                                          create_vllm_config)
-from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \
-    LLMDataDistCMgrConnectorMetadata
+from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
+    LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
 
 
 def test_basic_inferface():
@@ -40,3 +43,56 @@ def test_basic_inferface():
             req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
             single_type_managers[0].req_to_blocks[request_id]):
         assert block_id == block.block_id
+
+
+def test_read_agent_metadata():
+    rank_table = {
+        "version":
+        "1.2",
+        "server_count":
+        "2",
+        "prefill_device_list": [{
+            "server_id": "192.168.1.1",
+            "device_id": "0",
+            "device_ip": "10.30.0.1",
+            "cluster_id": "0",
+        }, {
+            "server_id": "192.168.1.1",
+            "device_id": "1",
+            "device_ip": "10.30.0.2",
+            "cluster_id": "1",
+        }, {
+            "server_id": "192.168.1.2",
+            "device_id": "0",
+            "device_ip": "10.30.0.3",
+            "cluster_id": "2",
+        }, {
+            "server_id": "192.168.1.2",
+            "device_id": "1",
+            "device_ip": "10.30.0.4",
+            "cluster_id": "3",
+        }]
+    }
+
+    def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
+        old_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
+        worker = types.SimpleNamespace()
+        worker.local_ip = worker_local_ip
+        worker.tp_rank = worker_tp_rank
+        worker.llm_datadist_role = LLMRole.PROMPT
+        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
+        # Restore the environment even if read_agent_metadata raises, so a
+        # failing case cannot leak the mutated value into other tests.
+        try:
+            agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
+                worker, rank_table)
+        finally:
+            os.environ["ASCEND_RT_VISIBLE_DEVICES"] = old_visible_devices
+        return agent_metadata.device_ip
+
+    assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
+    assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
+    assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
+    assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
+    assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
+    assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
+    assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
+    assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"
diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index 66fc313..84b2435 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -9,7 +9,7 @@ from collections.abc import Iterator
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import llm_datadist  # type: ignore
 import msgspec
@@ -331,9 +331,7 @@
         self.prefill_device_list: list[tuple[int, int]] = []
         self.decode_device_list: list[tuple[int, int]] = []
         global_rank_table = self.read_offline_rank_table()
-        self.local_agent_metadata = self.read_agent_metadata(
-            global_rank_table, self.local_ip, self.local_rank_on_node,
-            self.llm_datadist_role)
+        self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
         self.llm_datadist = LLMDataDist(self.llm_datadist_role,
                                         self.local_agent_metadata.cluster_id)
         self.init_llm_datadist()
@@ -448,8 +446,20 @@
         # global_rank_table = json.dumps(global_rank_table)
         return global_rank_table
 
-    def read_agent_metadata(self, global_rank_table, server_id, device_rank,
-                            agent_role):
+    @staticmethod
+    def _get_visible_devices() -> Callable[[str], bool]:
+        """
+        Return a test function that checks if the given device ID is visible,
+        i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
+        """
+        visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
+        if not visible_devices:
+            return lambda device_id: True
+        visible_device_list = visible_devices.split(",")
+        return lambda device_id: device_id in visible_device_list
+
+    def read_agent_metadata(self, global_rank_table):
+        device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
         devices_type_list = []
         agent_metadata = None
         if self.llm_datadist_role == LLMRole.PROMPT:
@@ -462,11 +472,12 @@
         for device_type in devices_type_list:
             device_list = global_rank_table[device_type]
             device_list = [
-                d for d in device_list if d.get("server_id") == server_id
+                d for d in device_list if d.get("server_id") == self.local_ip
+                and device_filter(d.get("device_id", ""))
             ]
-            if len(device_list) <= device_rank:
+            if len(device_list) <= self.tp_rank:
                 continue
-            device_info = device_list[device_rank]
+            device_info = device_list[self.tp_rank]
             super_pod_id_ = device_info.get("super_pod_id", None)
             server_id_ = device_info["server_id"]
             device_id_ = device_info["device_id"]
@@ -481,7 +492,7 @@
                 super_device_id=super_device_id_,
                 cluster_id=cluster_id_,
             )
-        assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table"
+        assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.tp_rank} from rank table"
         return agent_metadata
 
     def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):