Files
xc-llm-ascend/tests/e2e/nightly/multi_node/config/multi_node_config.py

286 lines
11 KiB
Python
Raw Normal View History

import logging
import os
import subprocess
from dataclasses import dataclass
from typing import Optional
import regex as re
import yaml
from tests.e2e.nightly.multi_node.config.utils import (get_all_ipv4,
get_avaliable_port,
get_cluster_ips,
get_net_interface,
setup_logger)
# Configure logging through the project helper at import time, then obtain
# this module's logger.
setup_logger()
logger = logging.getLogger(name=__name__)

# Fixed port constant for disaggregated prefill (not referenced in this
# module; presumably consumed elsewhere). NOTE: the identifier keeps its
# historical spelling "DISAGGEGATED" so existing importers keep working.
DISAGGEGATED_PREFILL_PORT = 5333
# Directory holding the per-model YAML configs. It is a relative path, so it
# is resolved against the current working directory when files are opened.
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/"
@dataclass
class NodeInfo:
    """Static description of one node in the multi-node test cluster."""

    index: int  # position of this node in the cluster list
    ip: str  # address other nodes use to reach this node
    server_cmd: str  # command used to launch this node's server
    headless: bool  # set by the YAML loader when server_cmd contains "--headless"
    server_port: int  # port this node's server listens on

    def __str__(self):
        # server_cmd is intentionally not part of the human-readable summary.
        parts = (
            ("index", self.index),
            ("ip", self.ip),
            ("headless", self.headless),
            ("server_port", self.server_port),
        )
        rendered = ["NodeInfo:"]
        rendered.extend(f" {label}={value}" for label, value in parts)
        return "\n".join(rendered)
class MultiNodeConfig:
    """Per-node runtime configuration for a multi-node e2e test.

    Usually built with :meth:`from_yaml`. Construction does three things:
    it works out which entry of ``nodes_info`` describes the machine this
    process runs on; it fills ``envs`` with the distributed-communication
    variables for that node; and it expands ``$VAR`` / ``${VAR}``
    references in that node's server command against ``envs``. The
    expanded command is stored as ``server_cmd``.
    """

    def __init__(self,
                 model: str,
                 test_name: str,
                 nodes_info: list[NodeInfo],
                 npu_per_node: int = 16,
                 server_port: int = 8080,
                 disaggregated_prefill: Optional[dict] = None,
                 envs: Optional[dict] = None,
                 perf_cmd: Optional[str] = None,
                 acc_cmd: Optional[str] = None):
        """Build the config and derive this node's environment.

        Args:
            model: Model identifier under test.
            test_name: Name of the test case.
            nodes_info: One entry per node, ordered by node index. Entry 0
                is used as the master node.
            npu_per_node: Number of NPUs per node (used by ``world_size``).
            server_port: Port the per-node servers listen on.
            disaggregated_prefill: Optional mapping with
                ``prefiller_host_index`` / ``decoder_host_index`` lists of
                node indices. When set, the prefill/decode split is enabled.
            envs: Extra environment variables for launched processes. The
                mapping is copied, so the caller's dict is never modified.
            perf_cmd: Optional performance benchmark command.
            acc_cmd: Optional accuracy benchmark command.

        Raises:
            ValueError: if ``nodes_info`` is empty.
            RuntimeError: if the current node cannot be identified, or the
                disaggregated-prefill config has no decoder indices.
        """
        if not nodes_info:
            raise ValueError("nodes_info must contain at least one node")
        self.test_name = test_name
        self.model = model
        self.nodes_info = nodes_info
        # The first node is treated as the master.
        # NOTE: with disaggregated prefill there are several node groups,
        # each with its own master; _init_dist_env picks MASTER_IP per group.
        self.master_ip = self.nodes_info[0].ip
        self.num_nodes = len(self.nodes_info)
        self.npu_per_node = npu_per_node
        self.server_port = server_port
        # Copy: _init_dist_env adds keys and rewrites values in place, which
        # must not leak into the dict the caller passed in.
        self.envs = dict(envs) if envs is not None else {}
        self.proxy_port = get_avaliable_port()
        self.perf_cmd = perf_cmd
        self.acc_cmd = acc_cmd
        self.disaggregated_prefill = disaggregated_prefill
        self._init_disaggregated_prefill()
        self._init_dist_env()
        self.server_cmd = self._expand_env_vars(self.node_info.server_cmd,
                                                self.envs)

    @property
    def cur_ip(self):
        """IP address of the node this process runs on."""
        return self.nodes_info[self.cur_index].ip

    @property
    def nic_name(self):
        """Network interface name for ``cur_ip`` (via ``get_net_interface``)."""
        return get_net_interface(self.cur_ip)

    @property
    def node_info(self):
        """``NodeInfo`` entry of the node this process runs on."""
        return self.nodes_info[self.cur_index]

    @property
    def cur_index(self):
        """Index of the current node in ``nodes_info`` (recomputed on each access).

        If ``LWS_WORKER_INDEX`` is set and non-empty, it is used. Otherwise
        the first local IPv4 address that appears in the cluster IP list
        decides the index.

        Raises:
            RuntimeError: if ``LWS_WORKER_INDEX`` is out of range, or no
                local IP matches a cluster IP.
        """
        # 1. Worker index from the K8s environment (presumably injected by
        #    a LeaderWorkerSet deployment).
        worker_index = os.environ.get("LWS_WORKER_INDEX")
        if worker_index:
            index = int(worker_index)
            # Reject out-of-range values: a negative index would otherwise
            # silently select a node counted from the end of the list.
            if not 0 <= index < len(self.nodes_info):
                raise RuntimeError(
                    f"LWS_WORKER_INDEX={worker_index} is out of range for "
                    f"{len(self.nodes_info)} configured nodes.")
            return index
        # 2. Fallback: match local IPs against the cluster IP list.
        cluster_ips = [node.ip for node in self.nodes_info]
        cluster_ip_set = set(cluster_ips)
        cur_ips = get_all_ipv4()
        for ip in cur_ips:
            if ip in cluster_ip_set:
                return cluster_ips.index(ip)
        raise RuntimeError(
            "Could not determine current node index: no matching IP.\n"
            f"Local machine IPs: {cur_ips}\n"
            f"Cluster IPs: {cluster_ips}\n"
            "Please check your config file or network settings.")

    def _init_disaggregated_prefill(self):
        """Derive the prefill/decode split from ``disaggregated_prefill``.

        Sets ``decode_start_index``, ``num_prefillers`` and ``num_decoders``.
        It assumes prefiller nodes occupy indices ``[0, decode_start_index)``
        and decoder nodes all remaining indices. The first entry of
        ``decoder_host_index`` is taken as the start of the decoders.
        Does nothing when disaggregated prefill is disabled.

        Raises:
            RuntimeError: if ``decoder_host_index`` is missing or empty.
        """
        if not self.disaggregated_prefill:
            return
        decode_host_index = self.disaggregated_prefill.get(
            "decoder_host_index")
        if not decode_host_index:
            raise RuntimeError("got empty decoder_host_index")
        self.decode_start_index: int = decode_host_index[0]
        self.num_prefillers = self.decode_start_index
        self.num_decoders = self.num_nodes - self.num_prefillers

    def _init_dist_env(self):
        """Populate ``self.envs`` with this node's distributed-communication settings.

        All keys and values are converted to ``str`` at the end so the
        mapping can be used directly as a subprocess environment.
        """
        # Resolve once: each property access re-runs node detection and the
        # interface lookup.
        cur_ip = self.cur_ip
        nic_name = self.nic_name
        self.envs["HCCL_IF_IP"] = cur_ip
        self.envs["GLOO_SOCKET_IFNAME"] = nic_name
        self.envs["TP_SOCKET_IFNAME"] = nic_name
        self.envs["HCCL_SOCKET_IFNAME"] = nic_name
        self.envs["LOCAL_IP"] = cur_ip
        self.envs["NIC_NAME"] = nic_name

        # Prefiller nodes (and all nodes without disaggregated prefill) use
        # node 0 as DP master; decoder nodes use the first decoder node.
        master_ip = self.master_ip
        if (self.disaggregated_prefill
                and self.cur_index >= self.decode_start_index):
            master_ip = self.nodes_info[self.decode_start_index].ip
        self.envs["MASTER_IP"] = master_ip

        ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
        existing = self.envs.get("LD_LIBRARY_PATH",
                                 os.environ.get("LD_LIBRARY_PATH", ""))
        # Skip empty parts: an empty LD_LIBRARY_PATH element (e.g. a trailing
        # ":") makes the dynamic loader search the current directory.
        self.envs["LD_LIBRARY_PATH"] = ":".join(
            part for part in (ascend_path, str(existing)) if part)

        # Keep env keys and values as strings (in place, same dict object).
        str_envs = {str(k): str(v) for k, v in self.envs.items()}
        self.envs.clear()
        self.envs.update(str_envs)

    @staticmethod
    def _expand_env_vars(cmd: str, env: dict) -> str:
        """Expand ``$VAR`` and ``${VAR}`` references in *cmd* using *env*.

        References whose name is not in *env* are left unchanged.
        """
        cmd = str(cmd)
        pattern = re.compile(r"\$(\w+)|\$\{(\w+)\}")

        def replace_var(match):
            var_name = match.group(1) or match.group(2)
            return str(env.get(var_name, match.group(0)))

        return pattern.sub(replace_var, cmd)

    class _ProxyContext:
        """Context manager that runs the disaggregated-prefill proxy on the master.

        It is a no-op when disaggregated prefill is disabled or when this is
        not the master node. On exit, a started proxy is terminated. If it
        is still running after 5 seconds, it is killed and reaped.
        """

        def __init__(self, outer, proxy_script):
            self.outer = outer  # owning MultiNodeConfig
            self.proxy_script = proxy_script  # path of the proxy script to run
            self.process = None  # Popen handle once launched

        def __enter__(self):
            o = self.outer
            if not o.disaggregated_prefill or not o.is_master:
                logger.info(
                    "Disaggregated prefill not enabled or not master node, skipping proxy launch."
                )
                return self

            prefiller_indices = o.disaggregated_prefill["prefiller_host_index"]
            decoder_indices = o.disaggregated_prefill["decoder_host_index"]
            # A node cannot be both a prefiller and a decoder.
            common_indices = set(prefiller_indices) & set(decoder_indices)
            assert not common_indices, f"Common indices found: {common_indices}"
            assert o.proxy_port is not None, "proxy_port must be set"

            cluster_ips = [node.ip for node in o.nodes_info]
            prefiller_ips = [cluster_ips[i] for i in prefiller_indices]
            decoder_ips = [cluster_ips[i] for i in decoder_indices]
            # Every node serves on the same port.
            prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
            decoder_ports_list = [str(o.server_port)] * len(decoder_ips)

            proxy_cmd = [
                "python",
                self.proxy_script,
                "--host",
                o.cur_ip,
                "--port",
                str(o.proxy_port),
                "--prefiller-hosts",
                *prefiller_ips,
                "--prefiller-ports",
                *prefiller_ports_list,
                "--decoder-hosts",
                *decoder_ips,
                "--decoder-ports",
                *decoder_ports_list,
            ]
            env = os.environ.copy()
            env.update(o.envs)
            logger.info(f"Launching proxy: {' '.join(proxy_cmd)}")
            self.process = subprocess.Popen(proxy_cmd, env=env)
            o.proxy_process = self.process
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returns None, so exceptions raised inside the with-body propagate.
            if self.process is not None:
                logger.info("Terminating proxy server process...")
                try:
                    self.process.terminate()
                    self.process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(
                        "Proxy process did not terminate, killing it...")
                    self.process.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    self.process.wait()
                logger.info("Proxy server process terminated.")

    def launch_server_proxy(self, proxy_script: str):
        """Return a context manager that launches the proxy server if disaggregated prefill is enabled."""
        return self._ProxyContext(self, proxy_script)

    @classmethod
    def from_yaml(cls, yaml_path: Optional[str] = None):
        """Create a config from a YAML file.

        Args:
            yaml_path: Path joined onto ``CONFIG_BASE_PATH`` (an absolute
                path is used as-is). If empty, the ``CONFIG_YAML_PATH`` env
                var is used, falling back to ``"DeepSeek-V3.yaml"``.

        Keys read from the YAML:
            - ``test_name``, ``model``, ``env_common``, ``num_nodes``,
              ``npu_per_node``, ``disaggregated_prefill``, ``server_port``.
            - ``deployment``: one entry with a ``server_cmd`` per node.
            - ``cluster_hosts``: optional list of node IPs; when absent,
              IPs are resolved with ``get_cluster_ips``.
            - ``benchmarks``: optional, with ``perf`` / ``acc`` commands.

        Raises:
            AssertionError: if the deployment or cluster host counts do not
                match ``num_nodes``.
        """
        if not yaml_path:
            yaml_path = os.getenv("CONFIG_YAML_PATH", "DeepSeek-V3.yaml")
        yaml_path = os.path.join(CONFIG_BASE_PATH, yaml_path)
        with open(yaml_path, 'r', encoding="utf-8") as file:
            # safe_load returns None for an empty document.
            config_data = yaml.safe_load(file) or {}

        test_name = config_data.get("test_name", "default_test")
        model = config_data.get("model", "default_model")
        envs = config_data.get("env_common", {})
        num_nodes = config_data.get("num_nodes", 2)
        npu_per_node = config_data.get("npu_per_node", 16)
        disaggregated_prefill = config_data.get("disaggregated_prefill")
        server_port = config_data.get("server_port", 8080)
        deployments = config_data.get("deployment") or []
        assert len(deployments) == num_nodes, \
            f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"

        cluster_ips = config_data.get("cluster_hosts", None)
        if cluster_ips:
            assert len(cluster_ips) == num_nodes, \
                "Must provide cluster_ips for all nodes if it is explicitly specified."
        else:
            logger.info("Resolving cluster IPs via DNS...")
            # Presumably returns num_nodes addresses ordered by node index.
            cluster_ips = get_cluster_ips(num_nodes)

        nodes_info = []
        # len(deployments) == num_nodes was checked above, so each deployment
        # is paired with the cluster IP at the same index.
        for index, deployment in enumerate(deployments):
            server_cmd = deployment.get("server_cmd", "")
            headless = "--headless" in server_cmd
            nodes_info.append(
                NodeInfo(ip=cluster_ips[index],
                         index=index,
                         headless=headless,
                         server_port=server_port,
                         server_cmd=server_cmd))

        # "benchmarks" is optional; a missing or empty section leaves both
        # commands as None.
        benchmarks = config_data.get("benchmarks") or {}
        perf_cmd = benchmarks.get("perf")
        acc_cmd = benchmarks.get("acc")

        return cls(model=model,
                   test_name=test_name,
                   npu_per_node=npu_per_node,
                   envs=envs,
                   server_port=server_port,
                   disaggregated_prefill=disaggregated_prefill,
                   nodes_info=nodes_info,
                   perf_cmd=perf_cmd,
                   acc_cmd=acc_cmd)

    @property
    def world_size(self):
        """Total number of NPUs across all nodes."""
        return self.num_nodes * self.npu_per_node

    @property
    def is_master(self):
        """True when this process runs on node 0."""
        return self.cur_index == 0