[CI][Doc] Optimize multi-node CI (#3565)
### What this PR does / why we need it?
This pull request mainly do the following things:
1. Add a doc for multi-node CI, The main content is the mechanism
principle and how to contribute
2. Simplify the config yaml for more developer-friendly
3. Optimized the mooncake installation script to prevent accidental
failures during installation
4. Fix the workflow to ensure the kubernetes can be apply correctly
5. Add Qwen3-235B-W8A8 disaggregated_prefill test
6. Add GLM-4.5 multi dp test
7. Add 2p1d 4nodes disaggregated_prefill test
8. Refactor nightly tests
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
@@ -15,6 +16,16 @@ from tests.e2e.nightly.multi_node.config.utils import (get_avaliable_port,
|
||||
setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
DISAGGREGATED_PREFILL_PROXY_SCRIPT = "examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py"
|
||||
DISAGGEGATED_PREFILL_PORT = 5333
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeInfo:
|
||||
index: int
|
||||
ip: str
|
||||
server_cmd: str
|
||||
headless: bool
|
||||
server_port: int
|
||||
|
||||
|
||||
class MultiNodeConfig:
|
||||
@@ -22,38 +33,50 @@ class MultiNodeConfig:
|
||||
def __init__(self,
|
||||
model: str,
|
||||
test_name: str,
|
||||
num_nodes: int = 2,
|
||||
npu_per_node: int = 16,
|
||||
server_port: int = 8080,
|
||||
headless: bool = False,
|
||||
disaggregated_prefill: Optional[dict] = None,
|
||||
envs: Optional[dict] = None,
|
||||
server_cmd: str = "",
|
||||
nodes_info: Optional[list[NodeInfo]] = None,
|
||||
perf_cmd: Optional[str] = None,
|
||||
acc_cmd: Optional[str] = None):
|
||||
self.test_name = test_name
|
||||
self.model = model
|
||||
self.num_nodes = num_nodes
|
||||
self.nodes_info = nodes_info or []
|
||||
self.num_nodes = len(self.nodes_info)
|
||||
self.npu_per_node = npu_per_node
|
||||
self.envs = envs if envs is not None else {}
|
||||
self.server_port = server_port
|
||||
if disaggregated_prefill:
|
||||
self.proxy_port = get_avaliable_port()
|
||||
self.headless = headless
|
||||
self.server_cmd = server_cmd
|
||||
self.envs = envs if envs is not None else {}
|
||||
self.proxy_port = get_avaliable_port()
|
||||
self.perf_cmd = perf_cmd
|
||||
self.acc_cmd = acc_cmd
|
||||
assert perf_cmd is not None, "perf_cmd must be provided"
|
||||
assert acc_cmd is not None, "acc_cmd must be provided"
|
||||
assert server_cmd is not None, "server_cmd must be provided"
|
||||
|
||||
self.cur_index = os.getenv("LWS_WORKER_INDEX", 0)
|
||||
self.cur_index = int(os.getenv("LWS_WORKER_INDEX", 0))
|
||||
self.cur_ip = get_cur_ip()
|
||||
self.nic_name = get_net_interface(self.cur_ip)
|
||||
self.cluster_ips = get_cluster_ips(num_nodes)
|
||||
self.cluster_ips = get_cluster_ips(self.num_nodes)
|
||||
self.cur_node_info: NodeInfo = self.nodes_info[self.cur_index]
|
||||
self.disaggregated_prefill = disaggregated_prefill
|
||||
self._init_disaggregated_prefill()
|
||||
|
||||
self._init_dist_env()
|
||||
self.server_cmd = self._expand_env_vars(self.server_cmd, self.envs)
|
||||
self.server_cmd = self._expand_env_vars(self.cur_node_info.server_cmd,
|
||||
self.envs)
|
||||
|
||||
def _init_disaggregated_prefill(self):
|
||||
if self.disaggregated_prefill:
|
||||
decode_host_index = self.disaggregated_prefill.get(
|
||||
"decoder_host_index")
|
||||
if not decode_host_index:
|
||||
raise RuntimeError("got empty decode_host_index")
|
||||
self.decode_start_index: int = decode_host_index[0]
|
||||
self.num_prefillers = self.decode_start_index
|
||||
self.num_decoders = self.num_nodes - self.num_prefillers
|
||||
if self.disaggregated_prefill.get(
|
||||
"ranktable_gen_path") is not None:
|
||||
self._gen_ranktable()
|
||||
|
||||
def _init_dist_env(self):
|
||||
self.envs["HCCL_IF_IP"] = self.cur_ip
|
||||
@@ -62,7 +85,17 @@ class MultiNodeConfig:
|
||||
self.envs["HCCL_SOCKET_IFNAME"] = self.nic_name
|
||||
self.envs["LOCAL_IP"] = self.cur_ip
|
||||
self.envs["NIC_NAME"] = self.nic_name
|
||||
self.envs["MASTER_IP"] = self.cluster_ips[0]
|
||||
|
||||
if self.disaggregated_prefill:
|
||||
self.envs[
|
||||
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = self.disaggregated_prefill.get(
|
||||
"ranktable_path")
|
||||
if self.cur_index < self.decode_start_index:
|
||||
self.envs["MASTER_IP"] = self.cluster_ips[0]
|
||||
else:
|
||||
self.envs["MASTER_IP"] = self.cluster_ips[
|
||||
self.decode_start_index]
|
||||
|
||||
ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
|
||||
self.envs[
|
||||
"LD_LIBRARY_PATH"] = f"{ascend_path}:{self.envs.get('LD_LIBRARY_PATH', os.environ.get('LD_LIBRARY_PATH', ''))}"
|
||||
@@ -172,15 +205,21 @@ class MultiNodeConfig:
|
||||
deployments = config_data.get("deployment", [])
|
||||
assert len(deployments) == num_nodes, \
|
||||
f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"
|
||||
for deployment in deployments:
|
||||
if deployment.get("local_index") == int(
|
||||
os.getenv("LWS_WORKER_INDEX", 0)):
|
||||
envs_extend = deployment.get("env_extend", {})
|
||||
if envs_extend:
|
||||
envs.update(envs_extend)
|
||||
server_cmd = deployment.get("server_cmd")
|
||||
headless = deployment.get("headless", False)
|
||||
break
|
||||
|
||||
cluster_ips = get_cluster_ips(num_nodes)
|
||||
nodes_info = []
|
||||
|
||||
for index, deployment in enumerate(deployments):
|
||||
# after assert len(deployments) == num_nodes, we can assume that this will must have a match
|
||||
server_cmd = deployment.get("server_cmd", "")
|
||||
headless = "--headless" in server_cmd
|
||||
nodes_info.append(
|
||||
NodeInfo(ip=cluster_ips[index],
|
||||
index=index,
|
||||
headless=headless,
|
||||
server_port=server_port,
|
||||
server_cmd=server_cmd))
|
||||
|
||||
benchmarks = config_data.get("benchmarks", {})
|
||||
assert benchmarks is not None, "benchmarks must be provided"
|
||||
perf_cmd = benchmarks["perf"]
|
||||
@@ -188,13 +227,11 @@ class MultiNodeConfig:
|
||||
|
||||
return cls(model=model,
|
||||
test_name=test_name,
|
||||
num_nodes=num_nodes,
|
||||
npu_per_node=npu_per_node,
|
||||
envs=envs,
|
||||
server_port=server_port,
|
||||
headless=headless,
|
||||
disaggregated_prefill=disaggregated_prefill,
|
||||
server_cmd=server_cmd,
|
||||
nodes_info=nodes_info,
|
||||
perf_cmd=perf_cmd,
|
||||
acc_cmd=acc_cmd)
|
||||
|
||||
@@ -204,4 +241,52 @@ class MultiNodeConfig:
|
||||
|
||||
@property
|
||||
def is_master(self):
|
||||
return int(self.cur_index) == 0
|
||||
return self.cur_index == 0
|
||||
|
||||
def _gen_ranktable(self):
|
||||
cluster_ip = self.cluster_ips
|
||||
assert len(cluster_ip) > 0
|
||||
nnodes = self.num_nodes
|
||||
node_rank = self.cur_index
|
||||
master_addr = cluster_ip[0]
|
||||
master_port = DISAGGEGATED_PREFILL_PORT
|
||||
assert self.disaggregated_prefill is not None
|
||||
ranktable_gen_path = self.disaggregated_prefill.get(
|
||||
"ranktable_gen_path")
|
||||
ranktable_path = self.disaggregated_prefill.get("ranktable_path")
|
||||
assert ranktable_gen_path is not None and ranktable_path is not None
|
||||
if os.path.exists(str(ranktable_path)):
|
||||
return
|
||||
|
||||
local_host = self.cur_ip
|
||||
|
||||
cmd = [
|
||||
"torchrun",
|
||||
"--nproc_per_node",
|
||||
"1",
|
||||
"--nnodes",
|
||||
str(nnodes),
|
||||
"--node_rank",
|
||||
str(node_rank),
|
||||
"--master_addr",
|
||||
master_addr,
|
||||
"--master_port",
|
||||
str(master_port),
|
||||
ranktable_gen_path,
|
||||
"--ranktable-path",
|
||||
str(ranktable_path),
|
||||
"--local-host",
|
||||
local_host,
|
||||
"--prefill-device-cnt",
|
||||
str(self.npu_per_node * self.num_prefillers),
|
||||
"--decode-device-cnt",
|
||||
str(self.npu_per_node * self.num_decoders),
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
assert self.nic_name is not None
|
||||
env["GLOO_SOCKET_IFNAME"] = self.nic_name
|
||||
|
||||
subprocess.run(cmd, env=env, check=True)
|
||||
assert os.path.exists(
|
||||
str(ranktable_path)), "failed generate ranktable.json"
|
||||
|
||||
Reference in New Issue
Block a user