import os
import re
import subprocess

import vllm_ascend.envs as envs

HCCN_TOOL_PATH = envs.HCCN_PATH

# Row of `npu-smi info -m` describing an Ascend device: three numeric columns
# followed by a name starting with "Ascend". Group 1 (the first column) is the
# ID handed to hccn_tool. Rows for other chips (e.g. Mcu) do not match.
_ASCEND_ROW_RE = re.compile(r'^\s*(\d+)\s+\d+\s+\d+\s+Ascend')

# "ipaddr:<value>" field in the output of `hccn_tool -i <id> -ip -g`.
_IPADDR_RE = re.compile(r'ipaddr:(.*)')


def get_device_ips():
    """Return the device IP of every Ascend NPU reported by ``npu-smi``.

    The device IDs are parsed from ``npu-smi info -m`` and each one is then
    queried with ``hccn_tool -i <id> -ip -g``.

    Returns:
        A list of IP strings, one per Ascend row, in the order the rows
        appear in the ``npu-smi`` output.

    Raises:
        RuntimeError: if ``npu-smi`` exits with a non-zero status, if
            ``HCCN_TOOL_PATH`` does not exist, if no Ascend row can be
            parsed, or if no non-empty ``ipaddr`` value is found for a
            device.
    """
    npu_info = subprocess.run(
        ['npu-smi', 'info', '-m'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
        raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")

    # Extract NPU IDs for all Ascend devices (excluding Mcu rows).
    device_ids = [
        int(row.group(1))
        for row in map(_ASCEND_ROW_RE.match, npu_info.stdout.splitlines())
        if row
    ]
    if not device_ids:
        raise RuntimeError(
            "Cannot parse any valid device ID from npu-smi output.")

    device_ip_list = []
    for device_id in device_ids:
        device_ip_info = subprocess.run(
            [HCCN_TOOL_PATH, '-i', str(device_id), '-ip', '-g'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        ip_match = _IPADDR_RE.search(device_ip_info.stdout)
        # An "ipaddr:" field with nothing after it is not a usable address,
        # so treat it the same as a missing field instead of returning "".
        device_ip = ip_match.group(1).strip() if ip_match else ""
        if not device_ip:
            raise RuntimeError(
                f"Cannot parse IP from hccn_tool for device {device_id}")
        device_ip_list.append(device_ip)

    return device_ip_list


# Script entry: print the IPs of all detected NPUs.
print(get_device_ips())