### What this PR does / why we need it?
1. Optimize the multi-node waiting logic.
2. Remove the `tee` pipeline for logs, which can lead to hangs.

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

Signed-off-by: wangli <wangli858794774@gmail.com>
343 lines
13 KiB
Python
343 lines
13 KiB
Python
import logging
|
|
import os
|
|
import subprocess
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import regex as re
|
|
import yaml
|
|
|
|
from tests.e2e.nightly.multi_node.config.utils import (get_all_ipv4,
|
|
get_avaliable_port,
|
|
get_cluster_ips,
|
|
get_net_interface,
|
|
setup_logger)
|
|
|
|
setup_logger()
logger = logging.getLogger(__name__)

# Fixed master port passed to torchrun when generating the
# disaggregated-prefill rank table.
DISAGGREGATED_PREFILL_PORT = 5333
# Backward-compatible alias for the historical misspelled name, so existing
# references and external importers keep working.
DISAGGEGATED_PREFILL_PORT = DISAGGREGATED_PREFILL_PORT

# Relative directory (resolved against the current working directory) that
# holds the YAML test configs loaded by MultiNodeConfig.from_yaml.
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/"
|
|
|
|
|
|
@dataclass
class NodeInfo:
    """Static description of a single node in a multi-node deployment.

    Fields are positional in this order: index, ip, server_cmd, headless,
    server_port. ``__str__`` intentionally leaves out ``server_cmd``.
    """

    index: int  # position of this node in the deployment list
    ip: str  # address other nodes use to reach this one
    server_cmd: str  # command line that launches the server on this node
    headless: bool  # whether server_cmd runs the server in --headless mode
    server_port: int  # port the server on this node listens on

    def __str__(self):
        shown = (
            ("index", self.index),
            ("ip", self.ip),
            ("headless", self.headless),
            ("server_port", self.server_port),
        )
        body = [f" {label}={value}" for label, value in shown]
        return "\n".join(["NodeInfo:", *body])
|
|
|
|
|
|
class MultiNodeConfig:
    """Configuration of a multi-node test, seen from the node running it.

    Construction resolves which entry of ``nodes_info`` is the current node,
    fills ``envs`` with the distributed-runtime variables for that node,
    expands ``$VAR``/``${VAR}`` references in the node's ``server_cmd`` and,
    when disaggregated prefill is configured with a ``ranktable_gen_path``,
    generates the rank table.
    """

    # Matches ``$NAME`` and ``${NAME}`` references in a command string.
    _ENV_VAR_PATTERN = re.compile(r"\$(\w+)|\$\{(\w+)\}")

    # Prepended to LD_LIBRARY_PATH in the environment of every node.
    _ASCEND_PYTHON_PATH = (
        "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages")

    def __init__(self,
                 model: str,
                 test_name: str,
                 nodes_info: "list[NodeInfo]",
                 npu_per_node: int = 16,
                 server_port: int = 8080,
                 disaggregated_prefill: Optional[dict] = None,
                 envs: Optional[dict] = None,
                 perf_cmd: Optional[str] = None,
                 acc_cmd: Optional[str] = None):
        """Build the config and prepare the current node's environment.

        Args:
            model: Model under test.
            test_name: Name of the test case.
            nodes_info: One entry per node; entry 0 is treated as the master.
            npu_per_node: Number of NPUs on every node.
            server_port: Server port used on every node (also handed to the
                proxy as the prefiller/decoder port).
            disaggregated_prefill: Optional disaggregated-prefill section
                (keys read: ``prefiller_host_index``, ``decoder_host_index``,
                ``ranktable_path``, ``ranktable_gen_path``).
            envs: Extra environment variables. The dict is copied, so the
                caller's mapping is never modified.
            perf_cmd: Optional performance benchmark command.
            acc_cmd: Optional accuracy benchmark command.
        """
        self.test_name = test_name
        self.model = model
        self.nodes_info = nodes_info
        # We assume the first node is the master.
        # NOTE: this may differ in scenarios like disaggregated prefill, where
        # there may be multiple groups of nodes, each with its own master.
        self.master_ip = self.nodes_info[0].ip
        self.num_nodes = len(self.nodes_info)
        self.npu_per_node = npu_per_node
        self.server_port = server_port
        # Copy: _init_dist_env adds keys, which must not leak to the caller.
        self.envs = dict(envs) if envs is not None else {}
        self.proxy_port = get_avaliable_port()
        self.perf_cmd = perf_cmd
        self.acc_cmd = acc_cmd

        self.disaggregated_prefill = disaggregated_prefill
        self._init_disaggregated_prefill()

        self._init_dist_env()
        self.server_cmd = self._expand_env_vars(self.node_info.server_cmd,
                                                self.envs)

    @property
    def cur_ip(self):
        """IP address of the current node, taken from ``nodes_info``."""
        return self.nodes_info[self.cur_index].ip

    @property
    def nic_name(self):
        """Network interface that carries the current node's IP."""
        return get_net_interface(self.cur_ip)

    @property
    def node_info(self):
        """``NodeInfo`` entry of the current node."""
        return self.nodes_info[self.cur_index]

    @property
    def cur_index(self):
        """Index of the current node in ``nodes_info``.

        Uses ``LWS_WORKER_INDEX`` when it is set and non-empty; otherwise
        returns the position of the first local IPv4 address that also
        appears in the cluster IP list.

        Raises:
            RuntimeError: if no local address matches a cluster IP.
        """
        # 1. Try to read worker index from K8s environment variable
        worker_index = os.environ.get("LWS_WORKER_INDEX")
        if worker_index:
            return int(worker_index)

        # 2. Fallback: match local IP against cluster IP list
        cluster_ips = [node.ip for node in self.nodes_info]
        cluster_ip_set = set(cluster_ips)
        cur_ips = get_all_ipv4()
        for ip in cur_ips:
            if ip in cluster_ip_set:
                return cluster_ips.index(ip)

        raise RuntimeError(
            "Could not determine current node index: no matching IP.\n"
            f"Local machine IPs: {cur_ips}\n"
            f"Cluster IPs: {cluster_ips}\n"
            "Please check your config file or network settings.")

    def _init_disaggregated_prefill(self):
        """Derive the prefiller/decoder split and generate the rank table.

        No-op when disaggregated prefill is not configured. Prefillers are
        taken to occupy indices ``[0, decode_start_index)`` and decoders the
        rest.

        Raises:
            RuntimeError: if ``decoder_host_index`` is missing or empty.
        """
        if not self.disaggregated_prefill:
            return
        decode_host_index = self.disaggregated_prefill.get(
            "decoder_host_index")
        if not decode_host_index:
            raise RuntimeError("got empty decode_host_index")
        # min() keeps the boundary independent of how the indices are listed.
        self.decode_start_index: int = min(decode_host_index)
        self.num_prefillers = self.decode_start_index
        self.num_decoders = self.num_nodes - self.num_prefillers
        if self.disaggregated_prefill.get("ranktable_gen_path") is not None:
            self._gen_ranktable()

    def _init_dist_env(self):
        """Populate ``self.envs`` with the current node's distributed settings.

        All values are converted to strings at the end so the dict can be
        used directly as a subprocess environment.
        """
        # Resolve once: each property access re-detects the node and NIC.
        cur_ip = self.cur_ip
        nic_name = self.nic_name
        self.envs["HCCL_IF_IP"] = cur_ip
        self.envs["GLOO_SOCKET_IFNAME"] = nic_name
        self.envs["TP_SOCKET_IFNAME"] = nic_name
        self.envs["HCCL_SOCKET_IFNAME"] = nic_name
        self.envs["LOCAL_IP"] = cur_ip
        self.envs["NIC_NAME"] = nic_name

        # For prefiller nodes, use the default master ip (index == 0).
        master_ip = self.master_ip
        if self.disaggregated_prefill:
            ranktable_path = self.disaggregated_prefill.get("ranktable_path")
            # Only export a real path; str(None) would yield the bogus "None".
            if ranktable_path is not None:
                self.envs[
                    "DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = ranktable_path
            if self.cur_index >= self.decode_start_index:
                # For decoder nodes, use the first decoder node as DP master
                master_ip = self.nodes_info[self.decode_start_index].ip
        self.envs["MASTER_IP"] = master_ip

        existing = self.envs.get("LD_LIBRARY_PATH",
                                 os.environ.get("LD_LIBRARY_PATH", ""))
        self.envs["LD_LIBRARY_PATH"] = self._prepend_path(
            self._ASCEND_PYTHON_PATH, existing)

        # keep the envs keys and values as strings
        self.envs = {k: str(v) for k, v in self.envs.items()}

    @staticmethod
    def _prepend_path(entry: str, path_list: Optional[str]) -> str:
        """Return ``path_list`` with ``entry`` prepended, ':'-separated.

        An empty or missing ``path_list`` yields just ``entry``. A trailing
        empty element is deliberately avoided, because the dynamic loader
        treats an empty LD_LIBRARY_PATH entry as the current directory.
        """
        if not path_list:
            return entry
        return f"{entry}:{path_list}"

    @staticmethod
    def _expand_env_vars(cmd: str, env: dict) -> str:
        """Expand ``$VAR`` / ``${VAR}`` references in ``cmd`` from ``env``.

        Only ``env`` is consulted (not ``os.environ``). References to unknown
        names are left verbatim. ``cmd`` is converted with ``str()`` first.
        """

        def replace_var(match):
            var_name = match.group(1) or match.group(2)
            return str(env.get(var_name, match.group(0)))

        return MultiNodeConfig._ENV_VAR_PATTERN.sub(replace_var, str(cmd))

    class _ProxyContext:
        """Context manager running the disaggregated-prefill proxy.

        The proxy is launched only on the master node and only when
        disaggregated prefill is configured; otherwise entering and exiting
        do nothing.
        """

        def __init__(self, outer, proxy_script):
            self.outer = outer
            self.proxy_script = proxy_script
            self.process = None

        def __enter__(self):
            o = self.outer
            if not o.disaggregated_prefill or not o.is_master:
                logger.info(
                    "Disaggregated prefill not enabled or not master node, skipping proxy launch."
                )
                return self

            prefiller_indices = o.disaggregated_prefill["prefiller_host_index"]
            decoder_indices = o.disaggregated_prefill["decoder_host_index"]

            common_indices = set(prefiller_indices) & set(decoder_indices)
            assert not common_indices, f"Common indices found: {common_indices}"
            assert o.proxy_port is not None, "proxy_port must be set"

            cluster_ips = [node.ip for node in o.nodes_info]
            prefiller_ips = [cluster_ips[i] for i in prefiller_indices]
            decoder_ips = [cluster_ips[i] for i in decoder_indices]
            # Every prefiller/decoder serves on the shared server_port.
            prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
            decoder_ports_list = [str(o.server_port)] * len(decoder_ips)

            proxy_cmd = [
                "python",
                self.proxy_script,
                "--host",
                o.cur_ip,
                "--port",
                str(o.proxy_port),
                "--prefiller-hosts",
                *prefiller_ips,
                "--prefiller-ports",
                *prefiller_ports_list,
                "--decoder-hosts",
                *decoder_ips,
                "--decoder-ports",
                *decoder_ports_list,
            ]

            env = os.environ.copy()
            env.update(o.envs)
            logger.info(f"Launching proxy: {' '.join(proxy_cmd)}")

            self.process = subprocess.Popen(proxy_cmd, env=env)
            o.proxy_process = self.process
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.process is None:
                return
            logger.info("Terminating proxy server process...")
            try:
                self.process.terminate()
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                logger.warning(
                    "Proxy process did not terminate, killing it...")
                self.process.kill()
                # Reap the killed child so it does not linger as a zombie.
                self.process.wait()
            logger.info("Proxy server process terminated.")

    def launch_server_proxy(self, proxy_script: str):
        """Return a context manager that launches the proxy server if disaggregated prefill is enabled."""
        return self._ProxyContext(self, proxy_script)

    @classmethod
    def from_yaml(cls, yaml_path: Optional[str] = None):
        """Build a config from a YAML file under ``CONFIG_BASE_PATH``.

        Args:
            yaml_path: File name joined onto ``CONFIG_BASE_PATH``. Defaults to
                ``$CONFIG_YAML_PATH``, or ``DeepSeek-V3.yaml`` when that is
                unset. An absolute path is used as-is by ``os.path.join``.

        Raises:
            AssertionError: if the number of ``deployment`` entries, or of
                explicit ``cluster_hosts``, differs from ``num_nodes``.
        """
        if not yaml_path:
            yaml_path = os.getenv("CONFIG_YAML_PATH", "DeepSeek-V3.yaml")
        yaml_path = os.path.join(CONFIG_BASE_PATH, yaml_path)
        with open(yaml_path, "r", encoding="utf-8") as file:
            # An empty file loads as None; treat it as an empty mapping so the
            # defaults apply and validation below reports a clear error.
            config_data = yaml.safe_load(file) or {}

        test_name = config_data.get("test_name", "default_test")
        model = config_data.get("model", "default_model")
        envs = config_data.get("env_common", {})
        num_nodes = config_data.get("num_nodes", 2)
        npu_per_node = config_data.get("npu_per_node", 16)
        disaggregated_prefill = config_data.get("disaggregated_prefill")
        # Every node (including prefillers and decoders) uses this port.
        server_port = config_data.get("server_port", 8080)

        deployments = config_data.get("deployment", [])
        assert len(deployments) == num_nodes, \
            f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"

        cluster_ips = config_data.get("cluster_hosts", None)
        if cluster_ips:
            assert len(cluster_ips) == num_nodes, \
                "Must provide cluster_ips for all nodes if it is explicitly specified."
        else:
            logger.info("Resolving cluster IPs via DNS...")
            cluster_ips = get_cluster_ips(num_nodes)

        # Deployments pair up with cluster IPs by index (assumes
        # get_cluster_ips returns num_nodes addresses).
        nodes_info = []
        for index, deployment in enumerate(deployments):
            server_cmd = deployment.get("server_cmd", "")
            nodes_info.append(
                NodeInfo(ip=cluster_ips[index],
                         index=index,
                         headless="--headless" in server_cmd,
                         server_port=server_port,
                         server_cmd=server_cmd))

        # The benchmarks section is optional; missing entries yield None.
        benchmarks = config_data.get("benchmarks") or {}
        perf_cmd = benchmarks.get("perf")
        acc_cmd = benchmarks.get("acc")

        return cls(model=model,
                   test_name=test_name,
                   npu_per_node=npu_per_node,
                   envs=envs,
                   server_port=server_port,
                   disaggregated_prefill=disaggregated_prefill,
                   nodes_info=nodes_info,
                   perf_cmd=perf_cmd,
                   acc_cmd=acc_cmd)

    @property
    def world_size(self):
        """Total number of NPUs across all nodes."""
        return self.num_nodes * self.npu_per_node

    @property
    def is_master(self):
        """True when the current node is entry 0 of ``nodes_info``."""
        return self.cur_index == 0

    def _gen_ranktable(self):
        """Generate the disaggregated-prefill rank table via torchrun.

        Runs ``ranktable_gen_path`` with one process per node, rendezvousing
        on node 0. Skips the run if ``ranktable_path`` already exists.

        Raises:
            subprocess.CalledProcessError: if the generator exits non-zero.
            AssertionError: if the rank table is still missing afterwards.
        """
        cluster_ip = [node.ip for node in self.nodes_info]
        assert len(cluster_ip) > 0
        nnodes = self.num_nodes
        node_rank = self.cur_index
        master_addr = cluster_ip[0]
        master_port = DISAGGEGATED_PREFILL_PORT
        assert self.disaggregated_prefill is not None
        ranktable_gen_path = self.disaggregated_prefill.get(
            "ranktable_gen_path")
        ranktable_path = self.disaggregated_prefill.get("ranktable_path")
        assert ranktable_gen_path is not None and ranktable_path is not None
        if os.path.exists(str(ranktable_path)):
            logger.info("ranktable has already generated")
            return

        local_host = self.cur_ip

        cmd = [
            "torchrun",
            "--nproc_per_node",
            "1",
            "--nnodes",
            str(nnodes),
            "--node_rank",
            str(node_rank),
            "--master_addr",
            master_addr,
            "--master_port",
            str(master_port),
            ranktable_gen_path,
            "--ranktable-path",
            str(ranktable_path),
            "--local-host",
            local_host,
            "--prefill-device-cnt",
            str(self.npu_per_node * self.num_prefillers),
            "--decode-device-cnt",
            str(self.npu_per_node * self.num_decoders),
        ]

        env = os.environ.copy()
        nic_name = self.nic_name
        assert nic_name is not None
        env["GLOO_SOCKET_IFNAME"] = nic_name

        logger.info(
            f"Generating ranktable from command: {' '.join(map(str, cmd))}")
        subprocess.run(cmd, env=env, check=True)
        assert os.path.exists(
            str(ranktable_path)), "failed generate ranktable.json"