# Post-patch network helpers from tests/e2e/nightly/multi_node/config/utils.py,
# reconstructed from a whitespace-collapsed diff.
# NOTE(review): temp_env, get_avaliable_port and get_net_interface were only
# partially visible in that diff and are intentionally not reproduced here.
import os
import socket
import time
from collections.abc import Callable

# Exponential backoff shared by the retry loops below: each wait is multiplied
# by the factor and capped at the maximum.
_BACKOFF_FACTOR = 1.5
_MAX_DELAY_SECONDS = 5.0


def dns_resolver(retries: int = 20,
                 base_delay: float = 0.5) -> Callable[[str], str]:
    """Build a hostname resolver that retries on DNS lookup failures.

    DNS resolution may fail right after a pod starts, so lookups are retried
    with a capped exponential backoff to ride out transient network issues.

    Args:
        retries: Total number of lookup attempts. Values below 1 are treated
            as 1 so that the resolver never silently returns ``None``.
        base_delay: Seconds to wait after the first failed attempt.

    Returns:
        A function mapping a hostname to the IPv4 address string returned by
        ``socket.gethostbyname``.

    Raises:
        socket.gaierror: Raised by the returned function when the final
            attempt also fails.
    """
    attempts = max(retries, 1)

    def resolve(dns: str) -> str:
        delay = base_delay
        # Every attempt but the last swallows the failure and backs off.
        for _ in range(attempts - 1):
            try:
                return socket.gethostbyname(dns)
            except socket.gaierror:
                time.sleep(delay)
                delay = min(delay * _BACKOFF_FACTOR, _MAX_DELAY_SECONDS)
        # Last attempt: let socket.gaierror propagate to the caller.
        return socket.gethostbyname(dns)

    return resolve


def get_cluster_dns_list(word_size: int) -> list[str]:
    """Return the DNS names of all nodes in the cluster.

    Index 0 is the leader (read from ``LWS_LEADER_ADDRESS``); indices
    1..word_size-1 are the workers.

    Args:
        word_size: Total number of nodes, leader included.

    Raises:
        RuntimeError: If ``LWS_LEADER_ADDRESS`` is unset or empty.
    """
    leader_dns = os.getenv("LWS_LEADER_ADDRESS")
    if not leader_dns:
        raise RuntimeError("LWS_LEADER_ADDRESS is not set")

    # NOTE(review): the worker name pattern is hard-coded; presumably it
    # matches the deployment's pod/service/namespace naming -- confirm.
    workers = [f"vllm-0-{i}.vllm.vllm-project" for i in range(1, word_size)]
    return [leader_dns] + workers


def get_cluster_ips(word_size: int = 2) -> list[str]:
    """Return the IP addresses of all nodes in the cluster.

    The order matches ``get_cluster_dns_list``: leader first, then workers.
    Each name is resolved with the retrying resolver from ``dns_resolver``.

    Args:
        word_size: Total number of nodes, leader included.

    Raises:
        RuntimeError: If ``LWS_LEADER_ADDRESS`` is unset or empty.
        socket.gaierror: If a name still fails to resolve after all retries.
    """
    resolver = dns_resolver()
    return [resolver(dns) for dns in get_cluster_dns_list(word_size)]


def get_cur_ip(retries: int = 20, base_delay: float = 0.5) -> str:
    """Return the pod/machine's primary IP address, retrying on failure.

    Network interfaces may not be ready immediately after container startup,
    so detection is retried with a capped exponential backoff.

    Args:
        retries: Total number of detection attempts. Values below 1 are
            treated as 1 so that the function never silently returns ``None``.
        base_delay: Seconds to wait after the first failed attempt.

    Returns:
        The local address of a UDP socket connected towards 8.8.8.8, or, when
        that probe fails, the address the local hostname resolves to.

    Raises:
        RuntimeError: If neither method succeeds within the allowed attempts.
    """
    attempts = max(retries, 1)
    delay = base_delay
    last_error = None

    for attempt in range(attempts):
        try:
            # Connecting a UDP socket sends no packets; it only makes the OS
            # pick the outgoing interface, whose address is then read back.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                probe.connect(("8.8.8.8", 80))
                return probe.getsockname()[0]
        except OSError as err:
            # Deliberate best-effort: fall through to hostname resolution.
            last_error = err

        try:
            return socket.gethostbyname(socket.gethostname())
        except (OSError, UnicodeError) as err:
            # UnicodeError covers hostnames that fail IDNA encoding.
            last_error = err

        if attempt < attempts - 1:
            time.sleep(delay)
            delay = min(delay * _BACKOFF_FACTOR, _MAX_DELAY_SECONDS)

    raise RuntimeError("Failed to determine local IP address") from last_error