diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml index 52ce615c..b9a584ed 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml @@ -20,6 +20,9 @@ env_common: HCCL_BUFFSIZE: 1024 SERVER_PORT: 8080 NUMEXPR_MAX_THREADS: 128 + DISAGGREGATED_PREFILL_PROXY_SCRIPT: "examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py" +# For non-Kubernetes deployments, list the IPs of all nodes, in order, as follows: +# cluster_hosts: [] disaggregated_prefill: enabled: true prefiller_host_index: [0] diff --git a/tests/e2e/nightly/multi_node/config/multi_node_config.py b/tests/e2e/nightly/multi_node/config/multi_node_config.py index 9bde4581..56453cad 100644 --- a/tests/e2e/nightly/multi_node/config/multi_node_config.py +++ b/tests/e2e/nightly/multi_node/config/multi_node_config.py @@ -7,15 +7,14 @@ from typing import Optional import regex as re import yaml -from tests.e2e.nightly.multi_node.config.utils import (get_avaliable_port, +from tests.e2e.nightly.multi_node.config.utils import (get_all_ipv4, + get_avaliable_port, get_cluster_ips, - get_cur_ip, get_net_interface, setup_logger) setup_logger() logger = logging.getLogger(__name__) -DISAGGREGATED_PREFILL_PROXY_SCRIPT = "examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py" DISAGGEGATED_PREFILL_PORT = 5333 CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/" @@ -28,22 +27,33 @@ class NodeInfo: headless: bool server_port: int + def __str__(self): + return (f"NodeInfo:\n" + f" index={self.index}\n" + f" ip={self.ip}\n" + f" server_port={self.server_port}\n" + f" headless={self.headless}") + class MultiNodeConfig: def __init__(self, model: str, test_name: str, + nodes_info: list[NodeInfo], npu_per_node: int = 16, server_port: int = 8080, disaggregated_prefill: Optional[dict] = None, envs: Optional[dict] = None, - 
nodes_info: Optional[list[NodeInfo]] = None, perf_cmd: Optional[str] = None, acc_cmd: Optional[str] = None): self.test_name = test_name self.model = model - self.nodes_info = nodes_info or [] + self.nodes_info = nodes_info + # We assume the first node (index 0) is the master. + # NOTE: this may differ in scenarios such as disaggregated prefill, + # where there may be multiple groups of nodes, each with its own master. + self.master_ip = self.nodes_info[0].ip self.num_nodes = len(self.nodes_info) self.npu_per_node = npu_per_node self.server_port = server_port @@ -52,18 +62,48 @@ class MultiNodeConfig: self.perf_cmd = perf_cmd self.acc_cmd = acc_cmd - self.cur_index = int(os.getenv("LWS_WORKER_INDEX", 0)) - self.cur_ip = get_cur_ip() - self.nic_name = get_net_interface(self.cur_ip) - self.cluster_ips = get_cluster_ips(self.num_nodes) - self.cur_node_info: NodeInfo = self.nodes_info[self.cur_index] self.disaggregated_prefill = disaggregated_prefill self._init_disaggregated_prefill() self._init_dist_env() - self.server_cmd = self._expand_env_vars(self.cur_node_info.server_cmd, + self.server_cmd = self._expand_env_vars(self.node_info.server_cmd, self.envs) + @property + def cur_ip(self): + return self.nodes_info[self.cur_index].ip + + @property + def nic_name(self): + return get_net_interface(self.cur_ip) + + @property + def node_info(self): + return self.nodes_info[self.cur_index] + + @property + def cur_index(self): + # 1. Try to read worker index from K8s environment variable + worker_index = os.environ.get("LWS_WORKER_INDEX") + if worker_index: + return int(worker_index) + + # 2. 
Fallback: match local IP against cluster IP list + cluster_ips = [node.ip for node in self.nodes_info] + cluster_ip_set = set(cluster_ips) + + cur_ips = get_all_ipv4() + + for ip in cur_ips: + if ip in cluster_ip_set: + return cluster_ips.index(ip) + + raise RuntimeError( + "Could not determine current node index: no matching IP.\n" + f"Local machine IPs: {cur_ips}\n" + f"Cluster IPs: {cluster_ips}\n" + "Please check your config file or network settings.") + def _init_disaggregated_prefill(self): if self.disaggregated_prefill: decode_host_index = self.disaggregated_prefill.get( @@ -85,15 +125,17 @@ class MultiNodeConfig: self.envs["LOCAL_IP"] = self.cur_ip self.envs["NIC_NAME"] = self.nic_name - master_ip = self.cluster_ips[0] + master_ip = self.master_ip if self.disaggregated_prefill: self.envs[ "DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = self.disaggregated_prefill.get( "ranktable_path") if self.cur_index < self.decode_start_index: - master_ip = self.cluster_ips[0] + # For prefiller nodes, use the default master ip(index==0) as DP master + master_ip = self.master_ip else: - master_ip = self.cluster_ips[self.decode_start_index] + # For decoder nodes, use the first decoder node as DP master + master_ip = self.nodes_info[self.decode_start_index].ip self.envs["MASTER_IP"] = master_ip ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages" @@ -139,8 +181,9 @@ class MultiNodeConfig: assert not common_indices, f"Common indices found: {common_indices}" assert o.proxy_port is not None, "proxy_port must be set" - prefiller_ips = [o.cluster_ips[i] for i in prefiller_indices] - decoder_ips = [o.cluster_ips[i] for i in decoder_indices] + cluster_ips = [node.ip for node in o.nodes_info] + prefiller_ips = [cluster_ips[i] for i in prefiller_indices] + decoder_ips = [cluster_ips[i] for i in decoder_indices] prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips) decoder_ports_list = [str(o.server_port)] * len(decoder_ips) @@ -204,8 +247,13 @@ class 
MultiNodeConfig: deployments = config_data.get("deployment", []) assert len(deployments) == num_nodes, \ f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})" - - cluster_ips = get_cluster_ips(num_nodes) + cluster_ips = config_data.get("cluster_hosts", None) + if cluster_ips: + assert len(cluster_ips) == num_nodes, \ + "Must provide cluster_ips for all nodes if it is explicitly specified." + else: + logger.info("Resolving cluster IPs via DNS...") + cluster_ips = get_cluster_ips(num_nodes) nodes_info = [] for index, deployment in enumerate(deployments): @@ -243,7 +291,7 @@ class MultiNodeConfig: return self.cur_index == 0 def _gen_ranktable(self): - cluster_ip = self.cluster_ips + cluster_ip = [nodes.ip for nodes in self.nodes_info] assert len(cluster_ip) > 0 nnodes = self.num_nodes node_rank = self.cur_index diff --git a/tests/e2e/nightly/multi_node/config/utils.py b/tests/e2e/nightly/multi_node/config/utils.py index 1b8f9e5f..95fcad5b 100644 --- a/tests/e2e/nightly/multi_node/config/utils.py +++ b/tests/e2e/nightly/multi_node/config/utils.py @@ -107,6 +107,19 @@ def get_net_interface(ip: Optional[str] = None) -> Optional[str]: return None +def get_all_ipv4(): + """get all the ipv4 address for current node""" + ipv4s = set() + hostname = socket.gethostname() + + for info in socket.getaddrinfo(hostname, None, family=socket.AF_INET): + ipv4s.add(info[4][0]) + + ipv4s.add("127.0.0.1") + + return list(ipv4s) + + def setup_logger(): """Setup logging configuration.""" logging.basicConfig( diff --git a/tests/e2e/nightly/multi_node/test_multi_node.py b/tests/e2e/nightly/multi_node/test_multi_node.py index 2b23e755..212ad26d 100644 --- a/tests/e2e/nightly/multi_node/test_multi_node.py +++ b/tests/e2e/nightly/multi_node/test_multi_node.py @@ -7,8 +7,8 @@ from modelscope import snapshot_download # type: ignore from requests.exceptions import ConnectionError, HTTPError, Timeout from tests.e2e.conftest import RemoteOpenAIServer -from 
tests.e2e.nightly.multi_node.config.multi_node_config import ( - DISAGGREGATED_PREFILL_PROXY_SCRIPT, MultiNodeConfig) +from tests.e2e.nightly.multi_node.config.multi_node_config import \ + MultiNodeConfig from tools.aisbench import run_aisbench_cases prompts = [ @@ -100,8 +100,10 @@ async def test_multi_node() -> None: disaggregated_prefill = config.disaggregated_prefill server_port = config.server_port proxy_port = config.proxy_port - server_host = config.cluster_ips[0] - with config.launch_server_proxy(DISAGGREGATED_PREFILL_PROXY_SCRIPT): + server_host = config.node_info.ip + proxy_script = config.envs.get("DISAGGREGATED_PREFILL_PROXY_SCRIPT", \ + 'examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py') + with config.launch_server_proxy(proxy_script): with RemoteOpenAIServer( model=local_model_path, vllm_serve_args=config.server_cmd,