Files
xc-llm-ascend/tests/e2e/nightly/multi_node/config/multi_node_config.py
Li Wang 4c4a8458a5 [CI] Refactor multi-node CI (#3487)
### What this PR does / why we need it?
Refactor the multi-machine CI test cases. The goal of this PR is to make
adding multi-machine CI test cases easier: developers can add a
multi-node cluster model test (including prefill/decode (PD)
disaggregation) simply by adding a new YAML configuration
file.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2025-10-17 09:04:31 +08:00

208 lines
8.1 KiB
Python

import logging
import os
import subprocess
from typing import Optional
import regex as re
import yaml
from tests.e2e.nightly.multi_node.config.utils import (get_avaliable_port,
get_cluster_ips,
get_cur_ip,
get_net_interface,
setup_logger)
# Install the shared logging configuration first, then obtain this module's
# logger so it picks that configuration up.
setup_logger()
logger = logging.getLogger(__name__)

# Script path for the disaggregated-prefill load-balancing proxy
# (presumably relative to the repository root — verify against callers).
DISAGGREGATED_PREFILL_PROXY_SCRIPT = (
    "examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py"
)
class MultiNodeConfig:
    """Per-node configuration for a multi-node nightly e2e test.

    Usually built with :meth:`from_yaml`. Construction resolves this node's
    IP and network interface plus the IPs of all cluster nodes. It then
    injects the distributed-communication environment variables into
    ``envs`` (every key and value coerced to ``str``). Finally it expands
    ``$VAR`` / ``${VAR}`` references in ``server_cmd`` against those envs.
    """

    # Matches ``$NAME`` (group 1) or ``${NAME}`` (group 2).
    _ENV_VAR_PATTERN = re.compile(r"\$(\w+)|\$\{(\w+)\}")

    def __init__(self,
                 model: str,
                 test_name: str,
                 num_nodes: int = 2,
                 npu_per_node: int = 16,
                 server_port: int = 8080,
                 headless: bool = False,
                 disaggregated_prefill: Optional[dict] = None,
                 envs: Optional[dict] = None,
                 server_cmd: str = "",
                 perf_cmd: Optional[str] = None,
                 acc_cmd: Optional[str] = None):
        """Create the configuration for the current node.

        Args:
            model: Model identifier taken from the test config.
            test_name: Name of the test case.
            num_nodes: Number of nodes in the cluster.
            npu_per_node: NPUs per node (used by ``world_size``).
            server_port: Port of the serving process on each node. The proxy
                is pointed at this port on every prefiller/decoder host.
            headless: Whether this node's deployment is marked headless.
            disaggregated_prefill: Optional dict with ``prefiller_host_index``
                and ``decoder_host_index`` lists. When truthy, a port from
                ``get_avaliable_port()`` is stored in ``proxy_port``.
            envs: Extra environment variables. The dict is copied, so the
                caller's mapping is not modified.
            server_cmd: Command that launches the server on this node.
            perf_cmd: Performance benchmark command (required).
            acc_cmd: Accuracy benchmark command (required).

        Raises:
            AssertionError: If ``perf_cmd``, ``acc_cmd`` or ``server_cmd``
                is None.
        """
        self.test_name = test_name
        self.model = model
        self.num_nodes = num_nodes
        self.npu_per_node = npu_per_node
        # Copy so that _init_dist_env() does not mutate the caller's dict.
        self.envs = dict(envs) if envs is not None else {}
        self.server_port = server_port
        # Always defined; only populated in disaggregated-prefill mode.
        self.proxy_port = None
        if disaggregated_prefill:
            self.proxy_port = get_avaliable_port()
        # Filled in by _ProxyContext.__enter__ once the proxy is launched.
        self.proxy_process = None
        self.headless = headless
        self.server_cmd = server_cmd
        self.perf_cmd = perf_cmd
        self.acc_cmd = acc_cmd
        assert perf_cmd is not None, "perf_cmd must be provided"
        assert acc_cmd is not None, "acc_cmd must be provided"
        assert server_cmd is not None, "server_cmd must be provided"
        # Node index from LWS_WORKER_INDEX (default 0); 0 is the master,
        # see is_master.
        self.cur_index = os.getenv("LWS_WORKER_INDEX", 0)
        self.cur_ip = get_cur_ip()
        self.nic_name = get_net_interface(self.cur_ip)
        self.cluster_ips = get_cluster_ips(num_nodes)
        self.disaggregated_prefill = disaggregated_prefill
        self._init_dist_env()
        # Expand after _init_dist_env so the command can reference injected
        # variables such as $LOCAL_IP or $MASTER_IP.
        self.server_cmd = self._expand_env_vars(self.server_cmd, self.envs)

    def _init_dist_env(self):
        """Populate ``self.envs`` with the variables needed for multi-node runs.

        The steps are:
        - Set the HCCL/GLOO/TP interface variables and ``LOCAL_IP`` /
          ``NIC_NAME`` from this node.
        - Set ``MASTER_IP`` to the first cluster IP.
        - Prepend the Ascend toolkit site-packages directory to
          ``LD_LIBRARY_PATH``.
        - Convert every key and value to ``str``.
        """
        self.envs["HCCL_IF_IP"] = self.cur_ip
        self.envs["GLOO_SOCKET_IFNAME"] = self.nic_name
        self.envs["TP_SOCKET_IFNAME"] = self.nic_name
        self.envs["HCCL_SOCKET_IFNAME"] = self.nic_name
        self.envs["LOCAL_IP"] = self.cur_ip
        self.envs["NIC_NAME"] = self.nic_name
        self.envs["MASTER_IP"] = self.cluster_ips[0]

        ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
        existing = self.envs.get("LD_LIBRARY_PATH")
        if existing is None:
            existing = os.environ.get("LD_LIBRARY_PATH", "")
        existing = str(existing)
        # An empty LD_LIBRARY_PATH entry (e.g. a trailing ':') makes the
        # dynamic loader also search the current working directory, so only
        # add the separator when there is something to append.
        self.envs["LD_LIBRARY_PATH"] = (f"{ascend_path}:{existing}"
                                        if existing else ascend_path)

        # Keep the env keys and values as strings so the mapping can be
        # merged into a subprocess environment.
        str_envs = {str(k): str(v) for k, v in self.envs.items()}
        self.envs.clear()
        self.envs.update(str_envs)

    @staticmethod
    def _expand_env_vars(cmd: str, env: dict) -> str:
        """Expand ``$VAR`` and ``${VAR}`` references in ``cmd`` using ``env``.

        ``cmd`` is converted with ``str()`` first. References to names that
        are missing from ``env`` are left untouched.
        """

        def replace_var(match):
            var_name = match.group(1) or match.group(2)
            return str(env.get(var_name, match.group(0)))

        return MultiNodeConfig._ENV_VAR_PATTERN.sub(replace_var, str(cmd))

    class _ProxyContext:
        """Context manager that runs the disaggregated-prefill proxy.

        The proxy is only started on the master node, and only when
        disaggregated prefill is configured. It is terminated on exit.
        """

        def __init__(self, outer, proxy_script):
            self.outer = outer
            self.proxy_script = proxy_script
            self.process = None

        def __enter__(self):
            """Launch the proxy when this node is master in a disaggregated run.

            Prefiller and decoder hosts are looked up in ``cluster_ips`` by
            their configured indices, and all of them are addressed on
            ``server_port``. The two index sets must be disjoint.
            """
            o = self.outer
            if not o.disaggregated_prefill or not o.is_master:
                logger.info(
                    "Disaggregated prefill not enabled or not master node, skipping proxy launch."
                )
                return self
            prefiller_indices = o.disaggregated_prefill["prefiller_host_index"]
            decoder_indices = o.disaggregated_prefill["decoder_host_index"]
            common_indices = set(prefiller_indices) & set(decoder_indices)
            assert not common_indices, f"Common indices found: {common_indices}"
            assert o.proxy_port is not None, "proxy_port must be set"
            prefiller_ips = [o.cluster_ips[i] for i in prefiller_indices]
            decoder_ips = [o.cluster_ips[i] for i in decoder_indices]
            prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
            decoder_ports_list = [str(o.server_port)] * len(decoder_ips)
            proxy_cmd = [
                "python",
                self.proxy_script,
                "--host",
                o.cur_ip,
                "--port",
                str(o.proxy_port),
                "--prefiller-hosts",
                *prefiller_ips,
                "--prefiller-ports",
                *prefiller_ports_list,
                "--decoder-hosts",
                *decoder_ips,
                "--decoder-ports",
                *decoder_ports_list,
            ]
            env = os.environ.copy()
            env.update(o.envs)
            logger.info(f"Launching proxy: {' '.join(proxy_cmd)}")
            self.process = subprocess.Popen(proxy_cmd, env=env)
            o.proxy_process = self.process
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            """Terminate the proxy, escalating to kill after 5 seconds.

            Exceptions raised inside the ``with`` body are not suppressed.
            """
            if self.process is not None:
                logger.info("Terminating proxy server process...")
                try:
                    self.process.terminate()
                    self.process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(
                        "Proxy process did not terminate, killing it...")
                    self.process.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    self.process.wait()
                logger.info("Proxy server process terminated.")

    def launch_server_proxy(self, proxy_script: str):
        """Return a context manager that launches the proxy server if disaggregated prefill is enabled."""
        return self._ProxyContext(self, proxy_script)

    @classmethod
    def from_yaml(cls, yaml_path: Optional[str] = None):
        """Build the configuration for the current node from a YAML file.

        Where values come from:
        - ``yaml_path`` defaults to ``$CONFIG_YAML_PATH``, or to the
          DeepSeek-V3 config if that variable is unset.
        - The ``deployment`` list must contain exactly ``num_nodes``
          entries. The entry whose ``local_index`` equals
          ``$LWS_WORKER_INDEX`` (default 0) supplies this node's
          ``server_cmd`` and ``headless`` flag, and its ``env_extend`` is
          merged over ``env_common``.
        - ``benchmarks.perf`` and ``benchmarks.acc`` supply the benchmark
          commands.

        Raises:
            AssertionError: If the deployment count differs from
                ``num_nodes`` or ``benchmarks`` is missing/empty.
            ValueError: If no deployment entry matches this node's index.
            KeyError: If ``benchmarks`` lacks ``perf`` or ``acc``.
        """
        if not yaml_path:
            yaml_path = os.getenv(
                "CONFIG_YAML_PATH",
                "tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml")
        with open(yaml_path, 'r', encoding="utf-8") as file:
            # An empty file loads as None; treat it as an empty mapping.
            config_data = yaml.safe_load(file) or {}

        test_name = config_data.get("test_name", "default_test")
        model = config_data.get("model", "default_model")
        # `or {}` also covers a key that is present with a null value.
        envs = config_data.get("env_common") or {}
        num_nodes = config_data.get("num_nodes", 2)
        npu_per_node = config_data.get("npu_per_node", 16)
        disaggregated_prefill = config_data.get("disaggregated_prefill")
        # Port of the serving process on every node. When
        # disaggregated_prefill is set, the separate proxy port is chosen in
        # __init__; server_port itself is not overridden.
        server_port = config_data.get("server_port", 8080)

        deployments = config_data.get("deployment") or []
        assert len(deployments) == num_nodes, \
            f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"

        cur_index = int(os.getenv("LWS_WORKER_INDEX", 0))
        for deployment in deployments:
            if deployment.get("local_index") == cur_index:
                break
        else:
            # Without a match, server_cmd/headless would be unbound; fail with
            # an actionable message instead of an UnboundLocalError.
            raise ValueError(
                f"No deployment with local_index={cur_index} found in {yaml_path}"
            )

        envs_extend = deployment.get("env_extend")
        if envs_extend:
            envs.update(envs_extend)
        server_cmd = deployment.get("server_cmd")
        headless = deployment.get("headless", False)

        benchmarks = config_data.get("benchmarks")
        assert benchmarks, "benchmarks must be provided"
        perf_cmd = benchmarks["perf"]
        acc_cmd = benchmarks["acc"]
        return cls(model=model,
                   test_name=test_name,
                   num_nodes=num_nodes,
                   npu_per_node=npu_per_node,
                   envs=envs,
                   server_port=server_port,
                   headless=headless,
                   disaggregated_prefill=disaggregated_prefill,
                   server_cmd=server_cmd,
                   perf_cmd=perf_cmd,
                   acc_cmd=acc_cmd)

    @property
    def world_size(self):
        """Total NPU count across the cluster (``num_nodes * npu_per_node``)."""
        return self.num_nodes * self.npu_per_node

    @property
    def is_master(self):
        """True when this node's worker index is 0."""
        return int(self.cur_index) == 0