[1/N] Refactor nightly test structure (#5479)
### What this PR does / why we need it?
This patch is a series of refactoring actions, including clarifying the
directory structure of nightly tests, refactoring the config retrieval
logic, and optimizing the workflow, etc. This is the first step:
refactoring the directory structure of nightly to make it more readable
and logical.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
0
tests/e2e/nightly/multi_node/scripts/__init__.py
Normal file
0
tests/e2e/nightly/multi_node/scripts/__init__.py
Normal file
349
tests/e2e/nightly/multi_node/scripts/multi_node_config.py
Normal file
349
tests/e2e/nightly/multi_node/scripts/multi_node_config.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
import yaml
|
||||
|
||||
# isort: off
|
||||
from tests.e2e.nightly.multi_node.scripts.utils import (
|
||||
CONFIG_BASE_PATH, DEFAULT_SERVER_PORT, get_all_ipv4, get_cluster_ips,
|
||||
get_net_interface, setup_logger, get_avaliable_port)
|
||||
# isort: on
|
||||
setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NodeInfo:
|
||||
index: int
|
||||
ip: str
|
||||
server_cmd: str
|
||||
envs: dict | None = None
|
||||
headless: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.ip:
|
||||
raise ValueError("NodeInfo.ip must not be empty")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ("NodeInfo(\n"
|
||||
f" index={self.index},\n"
|
||||
f" ip={self.ip},\n"
|
||||
f" headless={self.headless},\n"
|
||||
")")
|
||||
|
||||
|
||||
class DisaggregatedPrefillCfg:
|
||||
|
||||
def __init__(self, raw_cfg: dict, num_nodes: int):
|
||||
self.prefiller_indices: list[int] = raw_cfg.get(
|
||||
"prefiller_host_index", [])
|
||||
self.decoder_indices: list[int] = raw_cfg.get("decoder_host_index", [])
|
||||
|
||||
if not self.decoder_indices:
|
||||
raise RuntimeError("decoder_host_index must be provided")
|
||||
|
||||
self._validate(num_nodes)
|
||||
|
||||
self.decode_start_index = self.decoder_indices[0]
|
||||
self.num_prefillers = len(self.prefiller_indices)
|
||||
self.num_decoders = len(self.decoder_indices)
|
||||
|
||||
def _validate(self, num_nodes: int):
|
||||
overlap = set(self.prefiller_indices) & set(self.decoder_indices)
|
||||
if overlap:
|
||||
raise AssertionError(f"Prefiller and decoder overlap: {overlap}")
|
||||
|
||||
all_indices = self.prefiller_indices + self.decoder_indices
|
||||
if any(i >= num_nodes for i in all_indices):
|
||||
raise ValueError("Disaggregated prefill index out of range")
|
||||
|
||||
def is_prefiller(self, index: int) -> bool:
|
||||
return index in self.prefiller_indices
|
||||
|
||||
def is_decoder(self, index: int) -> bool:
|
||||
return index in self.decoder_indices
|
||||
|
||||
def master_ip_for_node(self, index: int, nodes: list[NodeInfo]) -> str:
|
||||
if self.is_prefiller(index):
|
||||
return nodes[0].ip
|
||||
return nodes[self.decode_start_index].ip
|
||||
|
||||
|
||||
class DistEnvBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cur_node: NodeInfo,
|
||||
master_ip: str,
|
||||
common_envs: dict,
|
||||
):
|
||||
self.cur_ip = cur_node.ip
|
||||
self.nic_name = get_net_interface(self.cur_ip)
|
||||
self.master_ip = master_ip
|
||||
|
||||
# envs
|
||||
common_envs = common_envs
|
||||
current_envs = cur_node.envs or {}
|
||||
# Node-specific envs override common envs
|
||||
self.base_envs = {**common_envs, **current_envs}
|
||||
|
||||
def build(self) -> dict:
|
||||
envs = dict(self.base_envs)
|
||||
|
||||
envs.update({
|
||||
"HCCL_IF_IP": self.cur_ip,
|
||||
"HCCL_SOCKET_IFNAME": self.nic_name,
|
||||
"GLOO_SOCKET_IFNAME": self.nic_name,
|
||||
"TP_SOCKET_IFNAME": self.nic_name,
|
||||
"LOCAL_IP": self.cur_ip,
|
||||
"NIC_NAME": self.nic_name,
|
||||
"MASTER_IP": self.master_ip,
|
||||
})
|
||||
|
||||
return {k: str(v) for k, v in envs.items()}
|
||||
|
||||
|
||||
class ProxyLauncher:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: list[NodeInfo],
|
||||
envs: dict,
|
||||
proxy_port: int,
|
||||
cur_index: int,
|
||||
disagg_cfg: DisaggregatedPrefillCfg | None = None,
|
||||
):
|
||||
self.nodes = nodes
|
||||
self.cfg = disagg_cfg
|
||||
self.server_port = envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||
self.proxy_port = proxy_port
|
||||
self.proxy_script = envs.get(
|
||||
"DISAGGREGATED_PREFILL_PROXY_SCRIPT",
|
||||
'examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py'
|
||||
)
|
||||
self.envs = envs
|
||||
self.is_master = cur_index == 0
|
||||
self.cur_ip = nodes[cur_index].ip
|
||||
self.process: Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
def __enter__(self):
|
||||
if not self.is_master or self.cfg is None:
|
||||
logger.info("Not launching proxy on non-master node")
|
||||
return self
|
||||
prefiller_ips = [self.nodes[i].ip for i in self.cfg.prefiller_indices]
|
||||
decoder_ips = [self.nodes[i].ip for i in self.cfg.decoder_indices]
|
||||
|
||||
cmd = [
|
||||
"python",
|
||||
self.proxy_script,
|
||||
"--host",
|
||||
self.cur_ip,
|
||||
"--port",
|
||||
str(self.proxy_port),
|
||||
"--prefiller-hosts",
|
||||
*prefiller_ips,
|
||||
"--prefiller-ports",
|
||||
*[str(self.server_port)] * len(prefiller_ips),
|
||||
"--decoder-hosts",
|
||||
*decoder_ips,
|
||||
"--decoder-ports",
|
||||
*[str(self.server_port)] * len(decoder_ips),
|
||||
]
|
||||
|
||||
logger.info("Launching proxy: %s", " ".join(cmd))
|
||||
self.process = subprocess.Popen(cmd, env={**os.environ, **self.envs})
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if not self.process:
|
||||
return
|
||||
logger.info("Stopping proxy server...")
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
|
||||
|
||||
class MultiNodeConfig:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
test_name: str,
|
||||
nodes: list[NodeInfo],
|
||||
npu_per_node: int,
|
||||
envs: dict,
|
||||
disaggregated_prefill: dict | None,
|
||||
perf_cmd: str | None,
|
||||
acc_cmd: str | None,
|
||||
):
|
||||
self.model = model
|
||||
self.test_name = test_name
|
||||
self.nodes = nodes
|
||||
self.npu_per_node = npu_per_node
|
||||
self.perf_cmd = perf_cmd
|
||||
self.acc_cmd = acc_cmd
|
||||
|
||||
self.cur_index = self._resolve_cur_index()
|
||||
self.cur_node = self.nodes[self.cur_index]
|
||||
|
||||
self.disagg_cfg = (DisaggregatedPrefillCfg(disaggregated_prefill,
|
||||
len(nodes))
|
||||
if disaggregated_prefill else None)
|
||||
|
||||
master_ip = (self.disagg_cfg.master_ip_for_node(
|
||||
self.cur_index, self.nodes)
|
||||
if self.disagg_cfg else self.nodes[0].ip)
|
||||
self.proxy_port = get_avaliable_port()
|
||||
|
||||
self.envs = DistEnvBuilder(
|
||||
cur_node=self.cur_node,
|
||||
master_ip=master_ip,
|
||||
common_envs=envs,
|
||||
).build()
|
||||
logger.info("Node %d envs: %s", self.cur_index, self.envs)
|
||||
|
||||
self.server_cmd = self._expand_env(self.cur_node.server_cmd)
|
||||
|
||||
def _resolve_cur_index(self) -> int:
|
||||
if (idx := os.environ.get("LWS_WORKER_INDEX")):
|
||||
return int(idx)
|
||||
|
||||
local_ips = get_all_ipv4()
|
||||
for i, node in enumerate(self.nodes):
|
||||
if node.ip in local_ips:
|
||||
return i
|
||||
|
||||
raise RuntimeError("Unable to determine current node index")
|
||||
|
||||
def _expand_env(self, cmd: str) -> str:
|
||||
pattern = re.compile(r"\$(\w+)|\$\{(\w+)\}")
|
||||
|
||||
def repl(m):
|
||||
key = m.group(1) or m.group(2)
|
||||
return self.envs.get(key, m.group(0))
|
||||
|
||||
return pattern.sub(repl, cmd)
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return len(self.nodes) * self.npu_per_node
|
||||
|
||||
@property
|
||||
def is_master(self) -> bool:
|
||||
return self.cur_index == 0
|
||||
|
||||
@property
|
||||
def server_port(self) -> int:
|
||||
return self.envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||
|
||||
@property
|
||||
def master_ip(self) -> str:
|
||||
return self.nodes[0].ip
|
||||
|
||||
@property
|
||||
def benchmark_endpoint(self) -> tuple[str, int]:
|
||||
"""
|
||||
Endpoint used by benchmark clients.
|
||||
"""
|
||||
master_ip = self.nodes[0].ip
|
||||
server_port = self.envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||
if self.disagg_cfg:
|
||||
return master_ip, self.proxy_port
|
||||
return master_ip, server_port
|
||||
|
||||
|
||||
class MultiNodeConfigLoader:
|
||||
"""Load MultiNodeConfig from yaml file."""
|
||||
|
||||
DEFAULT_CONFIG_NAME = "DeepSeek-V3.yaml"
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, yaml_path: Optional[str] = None) -> MultiNodeConfig:
|
||||
config = cls._load_yaml(yaml_path)
|
||||
cls._validate_root(config)
|
||||
|
||||
nodes = cls._parse_nodes(config)
|
||||
benchmarks = cls._parse_benchmarks(config)
|
||||
|
||||
return MultiNodeConfig(
|
||||
model=config["model"],
|
||||
test_name=config.get("test_name", "untitled_test"),
|
||||
nodes=nodes,
|
||||
npu_per_node=config.get("npu_per_node", 16),
|
||||
envs=config.get("env_common", {}),
|
||||
disaggregated_prefill=config.get("disaggregated_prefill"),
|
||||
perf_cmd=benchmarks.get("perf"),
|
||||
acc_cmd=benchmarks.get("acc"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_yaml(cls, yaml_path: Optional[str]) -> dict:
|
||||
if not yaml_path:
|
||||
yaml_path = os.getenv("CONFIG_YAML_PATH", cls.DEFAULT_CONFIG_NAME)
|
||||
|
||||
full_path = os.path.join(CONFIG_BASE_PATH, yaml_path)
|
||||
logger.info("Loading config yaml: %s", full_path)
|
||||
|
||||
with open(full_path, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
@staticmethod
|
||||
def _validate_root(cfg: dict):
|
||||
required = [
|
||||
"model", "deployment", "num_nodes", "npu_per_node", "env_common",
|
||||
"benchmarks"
|
||||
]
|
||||
missing = [k for k in required if k not in cfg]
|
||||
if missing:
|
||||
raise KeyError(f"Missing required config fields: {missing}")
|
||||
|
||||
@classmethod
|
||||
def _parse_nodes(cls, cfg: dict) -> list[NodeInfo]:
|
||||
num_nodes = cfg["num_nodes"]
|
||||
deployments = cfg["deployment"]
|
||||
|
||||
if len(deployments) != num_nodes:
|
||||
raise AssertionError(
|
||||
f"deployment size ({len(deployments)}) != num_nodes ({num_nodes})"
|
||||
)
|
||||
|
||||
cluster_ips = cls._resolve_cluster_ips(cfg, num_nodes)
|
||||
|
||||
nodes: list[NodeInfo] = []
|
||||
for idx, deploy in enumerate(deployments):
|
||||
cmd = deploy.get("server_cmd", "")
|
||||
envs = deploy.get("envs", {})
|
||||
nodes.append(
|
||||
NodeInfo(
|
||||
index=idx,
|
||||
ip=cluster_ips[idx],
|
||||
server_cmd=cmd,
|
||||
envs=envs,
|
||||
headless="--headless" in cmd,
|
||||
))
|
||||
return nodes
|
||||
|
||||
@staticmethod
|
||||
def _parse_benchmarks(cfg: dict) -> dict:
|
||||
benchmarks = cfg.get("benchmarks") or {}
|
||||
return benchmarks
|
||||
|
||||
@staticmethod
|
||||
def _resolve_cluster_ips(cfg: dict, num_nodes: int) -> list[str]:
|
||||
if "cluster_hosts" in cfg and cfg["cluster_hosts"]:
|
||||
ips = cfg["cluster_hosts"]
|
||||
if len(ips) != num_nodes:
|
||||
raise AssertionError("cluster_hosts size mismatch")
|
||||
return ips
|
||||
|
||||
logger.info("Resolving cluster IPs via DNS...")
|
||||
return get_cluster_ips(num_nodes)
|
||||
@@ -9,11 +9,15 @@ RED="\033[0;31m"
|
||||
NC="\033[0m" # No Color
|
||||
|
||||
# Configuration
|
||||
LOG_DIR="/root/.cache/tests/logs"
|
||||
OVERWRITE_LOGS=true
|
||||
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
|
||||
# Home path for aisbench
|
||||
export BENCHMARK_HOME=${WORKSPACE}/vllm-ascend/benchmark
|
||||
|
||||
# Logging configurations
|
||||
export VLLM_LOGGING_LEVEL="INFO"
|
||||
# Reduce glog verbosity for mooncake
|
||||
export GLOG_minloglevel=1
|
||||
# Set transformers to offline mode to avoid downloading models during tests
|
||||
export TRANSFORMERS_OFFLINE="1"
|
||||
|
||||
# Function to print section headers
|
||||
@@ -131,7 +135,7 @@ kill_npu_processes() {
|
||||
run_tests_with_log() {
|
||||
set +e
|
||||
kill_npu_processes
|
||||
pytest -sv --show-capture=no tests/e2e/nightly/multi_node/test_multi_node.py
|
||||
pytest -sv --show-capture=no tests/e2e/nightly/multi_node/scripts/test_multi_node.py
|
||||
ret=$?
|
||||
set -e
|
||||
if [ "$LWS_WORKER_INDEX" -eq 0 ]; then
|
||||
|
||||
46
tests/e2e/nightly/multi_node/scripts/test_multi_node.py
Normal file
46
tests/e2e/nightly/multi_node/scripts/test_multi_node.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import pytest
|
||||
|
||||
from tests.e2e.conftest import RemoteOpenAIServer
|
||||
from tests.e2e.nightly.multi_node.scripts.multi_node_config import (
|
||||
MultiNodeConfigLoader, ProxyLauncher)
|
||||
from tools.aisbench import run_aisbench_cases
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_node() -> None:
|
||||
config = MultiNodeConfigLoader.from_yaml()
|
||||
|
||||
with ProxyLauncher(
|
||||
nodes=config.nodes,
|
||||
disagg_cfg=config.disagg_cfg,
|
||||
envs=config.envs,
|
||||
proxy_port=config.proxy_port,
|
||||
cur_index=config.cur_index,
|
||||
) as proxy:
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model=config.model,
|
||||
vllm_serve_args=config.server_cmd,
|
||||
server_port=config.server_port,
|
||||
server_host=config.master_ip,
|
||||
env_dict=config.envs,
|
||||
auto_port=False,
|
||||
proxy_port=proxy.proxy_port,
|
||||
disaggregated_prefill=config.disagg_cfg,
|
||||
nodes_info=config.nodes,
|
||||
max_wait_seconds=2800,
|
||||
) as server:
|
||||
|
||||
host, port = config.benchmark_endpoint
|
||||
|
||||
if config.is_master:
|
||||
run_aisbench_cases(
|
||||
model=config.model,
|
||||
port=port,
|
||||
aisbench_cases=[config.acc_cmd, config.perf_cmd],
|
||||
host_ip=host,
|
||||
)
|
||||
else:
|
||||
# We should keep listening on the master node's server url determining when to exit.
|
||||
server.hang_until_terminated(
|
||||
f"http://{host}:{config.server_port}/health")
|
||||
149
tests/e2e/nightly/multi_node/scripts/utils.py
Normal file
149
tests/e2e/nightly/multi_node/scripts/utils.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
import psutil
|
||||
|
||||
DISAGGEGATED_PREFILL_PORT = 5333
|
||||
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/"
|
||||
DEFAULT_SERVER_PORT = 8080
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_env(env_dict):
|
||||
old_env = {}
|
||||
for k, v in env_dict.items():
|
||||
old_env[k] = os.environ.get(k)
|
||||
os.environ[k] = str(v)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for k, v in old_env.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
def dns_resolver(retries: int = 240, base_delay: float = 0.5):
|
||||
# We should resolve DNS with retries to avoid transient network issues.
|
||||
# When the pod is just started, DNS resolution may fail.
|
||||
def resolve(dns: str):
|
||||
delay = base_delay
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
return socket.gethostbyname(dns)
|
||||
except socket.gaierror:
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
time.sleep(delay)
|
||||
delay = min(delay * 1.5, 5)
|
||||
|
||||
return resolve
|
||||
|
||||
|
||||
def get_cluster_dns_list(world_size: int) -> List[str]:
|
||||
if world_size < 1:
|
||||
raise ValueError(f"world_size must be >= 1, got {world_size}")
|
||||
|
||||
leader_dns = os.getenv("LWS_LEADER_ADDRESS")
|
||||
if not leader_dns:
|
||||
raise RuntimeError(
|
||||
"environment variable LWS_LEADER_ADDRESS is not set")
|
||||
|
||||
# Expected format:
|
||||
# <leader-name>.<group-name>.<namespace>
|
||||
parts = leader_dns.split(".")
|
||||
if len(parts) < 3:
|
||||
raise ValueError(f"invalid leader DNS format: {leader_dns}")
|
||||
|
||||
leader_name, group_name, namespace = parts[0], parts[1], parts[2]
|
||||
|
||||
worker_dns_list = [
|
||||
f"{leader_name}-{idx}.{group_name}.{namespace}"
|
||||
for idx in range(1, world_size)
|
||||
]
|
||||
|
||||
return [leader_dns, *worker_dns_list]
|
||||
|
||||
|
||||
def get_cluster_ips(word_size: int = 2) -> list[str]:
|
||||
resolver = dns_resolver()
|
||||
return [resolver(dns) for dns in get_cluster_dns_list(word_size)]
|
||||
|
||||
|
||||
def get_avaliable_port(start_port: int = 6000, end_port: int = 7000) -> int:
|
||||
import socket
|
||||
for port in range(start_port, end_port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind(("", port))
|
||||
return port
|
||||
except OSError:
|
||||
continue
|
||||
raise RuntimeError("No available port found")
|
||||
|
||||
|
||||
def get_cur_ip(retries: int = 20, base_delay: float = 0.5):
|
||||
"""
|
||||
Returns the pod/machine's primary IP address with retry.
|
||||
This is necessary because network interfaces may not be ready
|
||||
immediately after container startup.
|
||||
"""
|
||||
delay = base_delay
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
# Best method: UDP trick (doesn't actually send packets)
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("8.8.8.8", 80))
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
# fallback: hostname resolution
|
||||
try:
|
||||
return socket.gethostbyname(socket.gethostname())
|
||||
except Exception:
|
||||
if attempt == retries - 1:
|
||||
raise RuntimeError("Failed to determine local IP address")
|
||||
time.sleep(delay)
|
||||
delay = min(delay * 1.5, 5)
|
||||
|
||||
|
||||
def get_net_interface(ip: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns specified IP's inetwork interface.
|
||||
If no IP is provided, uses the first from hostname -I.
|
||||
"""
|
||||
if ip is None:
|
||||
ip = get_cur_ip()
|
||||
|
||||
for iface, addrs in psutil.net_if_addrs().items():
|
||||
for addr in addrs:
|
||||
if addr.family == socket.AF_INET and addr.address == ip:
|
||||
return iface
|
||||
raise RuntimeError(f"No network interface found for IP {ip}")
|
||||
|
||||
|
||||
def get_all_ipv4():
|
||||
"""get all the ipv4 address for current node"""
|
||||
ipv4s = set()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
for info in socket.getaddrinfo(hostname, None, family=socket.AF_INET):
|
||||
ipv4s.add(info[4][0])
|
||||
|
||||
ipv4s.add("127.0.0.1")
|
||||
|
||||
return list(ipv4s)
|
||||
|
||||
|
||||
def setup_logger():
|
||||
"""Setup logging configuration."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
Reference in New Issue
Block a user