[CI][Nightly] Support local debugging for multi-node CI test cases (#4489)
### What this PR does / why we need it? This patch mainly doing the following things: 1. Make k8s/lws optional for multi-node testing, allowing developers to run multi-node tests locally by actively passing in the IP addresses of all nodes. 2. Allows passing a custom proxy script path in the config file to load the proxy. - vLLM version: v0.11.2 --------- Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -20,6 +20,9 @@ env_common:
|
|||||||
HCCL_BUFFSIZE: 1024
|
HCCL_BUFFSIZE: 1024
|
||||||
SERVER_PORT: 8080
|
SERVER_PORT: 8080
|
||||||
NUMEXPR_MAX_THREADS: 128
|
NUMEXPR_MAX_THREADS: 128
|
||||||
|
DISAGGREGATED_PREFILL_PROXY_SCRIPT: "examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py"
|
||||||
|
# For None kubernetes deployment, list the IPs of all nodes used in order as follow
|
||||||
|
# cluster_hosts: []
|
||||||
disaggregated_prefill:
|
disaggregated_prefill:
|
||||||
enabled: true
|
enabled: true
|
||||||
prefiller_host_index: [0]
|
prefiller_host_index: [0]
|
||||||
|
|||||||
@@ -7,15 +7,14 @@ from typing import Optional
|
|||||||
import regex as re
|
import regex as re
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from tests.e2e.nightly.multi_node.config.utils import (get_avaliable_port,
|
from tests.e2e.nightly.multi_node.config.utils import (get_all_ipv4,
|
||||||
|
get_avaliable_port,
|
||||||
get_cluster_ips,
|
get_cluster_ips,
|
||||||
get_cur_ip,
|
|
||||||
get_net_interface,
|
get_net_interface,
|
||||||
setup_logger)
|
setup_logger)
|
||||||
|
|
||||||
setup_logger()
|
setup_logger()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DISAGGREGATED_PREFILL_PROXY_SCRIPT = "examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py"
|
|
||||||
DISAGGEGATED_PREFILL_PORT = 5333
|
DISAGGEGATED_PREFILL_PORT = 5333
|
||||||
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/"
|
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/"
|
||||||
|
|
||||||
@@ -28,22 +27,33 @@ class NodeInfo:
|
|||||||
headless: bool
|
headless: bool
|
||||||
server_port: int
|
server_port: int
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (f"NodeInfo:\n"
|
||||||
|
f" index={self.index}\n"
|
||||||
|
f" ip={self.ip}\n"
|
||||||
|
f" server_port={self.server_port}\n"
|
||||||
|
f" headless={self.headless}")
|
||||||
|
|
||||||
|
|
||||||
class MultiNodeConfig:
|
class MultiNodeConfig:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: str,
|
model: str,
|
||||||
test_name: str,
|
test_name: str,
|
||||||
|
nodes_info: list[NodeInfo],
|
||||||
npu_per_node: int = 16,
|
npu_per_node: int = 16,
|
||||||
server_port: int = 8080,
|
server_port: int = 8080,
|
||||||
disaggregated_prefill: Optional[dict] = None,
|
disaggregated_prefill: Optional[dict] = None,
|
||||||
envs: Optional[dict] = None,
|
envs: Optional[dict] = None,
|
||||||
nodes_info: Optional[list[NodeInfo]] = None,
|
|
||||||
perf_cmd: Optional[str] = None,
|
perf_cmd: Optional[str] = None,
|
||||||
acc_cmd: Optional[str] = None):
|
acc_cmd: Optional[str] = None):
|
||||||
self.test_name = test_name
|
self.test_name = test_name
|
||||||
self.model = model
|
self.model = model
|
||||||
self.nodes_info = nodes_info or []
|
self.nodes_info = nodes_info
|
||||||
|
# We assume the first index of nodes as the master
|
||||||
|
# NOTE: this may be different in the scenarios like disaggregated prefill
|
||||||
|
# There may be multi groups of nodes, and the master of each group may be different
|
||||||
|
self.master_ip = self.nodes_info[0].ip
|
||||||
self.num_nodes = len(self.nodes_info)
|
self.num_nodes = len(self.nodes_info)
|
||||||
self.npu_per_node = npu_per_node
|
self.npu_per_node = npu_per_node
|
||||||
self.server_port = server_port
|
self.server_port = server_port
|
||||||
@@ -52,18 +62,48 @@ class MultiNodeConfig:
|
|||||||
self.perf_cmd = perf_cmd
|
self.perf_cmd = perf_cmd
|
||||||
self.acc_cmd = acc_cmd
|
self.acc_cmd = acc_cmd
|
||||||
|
|
||||||
self.cur_index = int(os.getenv("LWS_WORKER_INDEX", 0))
|
|
||||||
self.cur_ip = get_cur_ip()
|
|
||||||
self.nic_name = get_net_interface(self.cur_ip)
|
|
||||||
self.cluster_ips = get_cluster_ips(self.num_nodes)
|
|
||||||
self.cur_node_info: NodeInfo = self.nodes_info[self.cur_index]
|
|
||||||
self.disaggregated_prefill = disaggregated_prefill
|
self.disaggregated_prefill = disaggregated_prefill
|
||||||
self._init_disaggregated_prefill()
|
self._init_disaggregated_prefill()
|
||||||
|
|
||||||
self._init_dist_env()
|
self._init_dist_env()
|
||||||
self.server_cmd = self._expand_env_vars(self.cur_node_info.server_cmd,
|
self.server_cmd = self._expand_env_vars(self.node_info.server_cmd,
|
||||||
self.envs)
|
self.envs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_ip(self):
|
||||||
|
return self.nodes_info[self.cur_index].ip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nic_name(self):
|
||||||
|
return get_net_interface(self.cur_ip)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_info(self):
|
||||||
|
return self.nodes_info[self.cur_index]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_index(self):
|
||||||
|
# 1. Try to read worker index from K8s environment variable
|
||||||
|
worker_index = os.environ.get("LWS_WORKER_INDEX")
|
||||||
|
if worker_index:
|
||||||
|
return int(worker_index)
|
||||||
|
|
||||||
|
# 2. Fallback: match local IP against cluster IP list
|
||||||
|
cluster_ips = [node.ip for node in self.nodes_info]
|
||||||
|
cluster_ip_set = set(cluster_ips)
|
||||||
|
|
||||||
|
cur_ips = get_all_ipv4()
|
||||||
|
|
||||||
|
for ip in cur_ips:
|
||||||
|
if ip in cluster_ip_set:
|
||||||
|
return cluster_ips.index(ip)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"Could not determine current node index: no matching IP.\n"
|
||||||
|
f"Local machine IPs: {cur_ips}\n"
|
||||||
|
f"Cluster IPs: {cluster_ips}\n"
|
||||||
|
"Please check your config file or network settings.")
|
||||||
|
|
||||||
def _init_disaggregated_prefill(self):
|
def _init_disaggregated_prefill(self):
|
||||||
if self.disaggregated_prefill:
|
if self.disaggregated_prefill:
|
||||||
decode_host_index = self.disaggregated_prefill.get(
|
decode_host_index = self.disaggregated_prefill.get(
|
||||||
@@ -85,15 +125,17 @@ class MultiNodeConfig:
|
|||||||
self.envs["LOCAL_IP"] = self.cur_ip
|
self.envs["LOCAL_IP"] = self.cur_ip
|
||||||
self.envs["NIC_NAME"] = self.nic_name
|
self.envs["NIC_NAME"] = self.nic_name
|
||||||
|
|
||||||
master_ip = self.cluster_ips[0]
|
master_ip = self.master_ip
|
||||||
if self.disaggregated_prefill:
|
if self.disaggregated_prefill:
|
||||||
self.envs[
|
self.envs[
|
||||||
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = self.disaggregated_prefill.get(
|
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = self.disaggregated_prefill.get(
|
||||||
"ranktable_path")
|
"ranktable_path")
|
||||||
if self.cur_index < self.decode_start_index:
|
if self.cur_index < self.decode_start_index:
|
||||||
master_ip = self.cluster_ips[0]
|
# For prefiller nodes, use the default master ip(index==0) as DP master
|
||||||
|
master_ip = self.master_ip
|
||||||
else:
|
else:
|
||||||
master_ip = self.cluster_ips[self.decode_start_index]
|
# For decoder nodes, use the first decoder node as DP master
|
||||||
|
master_ip = self.nodes_info[self.decode_start_index].ip
|
||||||
|
|
||||||
self.envs["MASTER_IP"] = master_ip
|
self.envs["MASTER_IP"] = master_ip
|
||||||
ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
|
ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
|
||||||
@@ -139,8 +181,9 @@ class MultiNodeConfig:
|
|||||||
assert not common_indices, f"Common indices found: {common_indices}"
|
assert not common_indices, f"Common indices found: {common_indices}"
|
||||||
assert o.proxy_port is not None, "proxy_port must be set"
|
assert o.proxy_port is not None, "proxy_port must be set"
|
||||||
|
|
||||||
prefiller_ips = [o.cluster_ips[i] for i in prefiller_indices]
|
cluster_ips = [node.ip for node in o.nodes_info]
|
||||||
decoder_ips = [o.cluster_ips[i] for i in decoder_indices]
|
prefiller_ips = [cluster_ips[i] for i in prefiller_indices]
|
||||||
|
decoder_ips = [cluster_ips[i] for i in decoder_indices]
|
||||||
prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
|
prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
|
||||||
decoder_ports_list = [str(o.server_port)] * len(decoder_ips)
|
decoder_ports_list = [str(o.server_port)] * len(decoder_ips)
|
||||||
|
|
||||||
@@ -204,7 +247,12 @@ class MultiNodeConfig:
|
|||||||
deployments = config_data.get("deployment", [])
|
deployments = config_data.get("deployment", [])
|
||||||
assert len(deployments) == num_nodes, \
|
assert len(deployments) == num_nodes, \
|
||||||
f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"
|
f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"
|
||||||
|
cluster_ips = config_data.get("cluster_hosts", None)
|
||||||
|
if cluster_ips:
|
||||||
|
assert len(cluster_ips) == num_nodes, \
|
||||||
|
"Must provide cluster_ips for all nodes if it is explicitly specified."
|
||||||
|
else:
|
||||||
|
logger.info("Resolving cluster IPs via DNS...")
|
||||||
cluster_ips = get_cluster_ips(num_nodes)
|
cluster_ips = get_cluster_ips(num_nodes)
|
||||||
nodes_info = []
|
nodes_info = []
|
||||||
|
|
||||||
@@ -243,7 +291,7 @@ class MultiNodeConfig:
|
|||||||
return self.cur_index == 0
|
return self.cur_index == 0
|
||||||
|
|
||||||
def _gen_ranktable(self):
|
def _gen_ranktable(self):
|
||||||
cluster_ip = self.cluster_ips
|
cluster_ip = [nodes.ip for nodes in self.nodes_info]
|
||||||
assert len(cluster_ip) > 0
|
assert len(cluster_ip) > 0
|
||||||
nnodes = self.num_nodes
|
nnodes = self.num_nodes
|
||||||
node_rank = self.cur_index
|
node_rank = self.cur_index
|
||||||
|
|||||||
@@ -107,6 +107,19 @@ def get_net_interface(ip: Optional[str] = None) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_ipv4():
|
||||||
|
"""get all the ipv4 address for current node"""
|
||||||
|
ipv4s = set()
|
||||||
|
hostname = socket.gethostname()
|
||||||
|
|
||||||
|
for info in socket.getaddrinfo(hostname, None, family=socket.AF_INET):
|
||||||
|
ipv4s.add(info[4][0])
|
||||||
|
|
||||||
|
ipv4s.add("127.0.0.1")
|
||||||
|
|
||||||
|
return list(ipv4s)
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
def setup_logger():
|
||||||
"""Setup logging configuration."""
|
"""Setup logging configuration."""
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from modelscope import snapshot_download # type: ignore
|
|||||||
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
||||||
|
|
||||||
from tests.e2e.conftest import RemoteOpenAIServer
|
from tests.e2e.conftest import RemoteOpenAIServer
|
||||||
from tests.e2e.nightly.multi_node.config.multi_node_config import (
|
from tests.e2e.nightly.multi_node.config.multi_node_config import \
|
||||||
DISAGGREGATED_PREFILL_PROXY_SCRIPT, MultiNodeConfig)
|
MultiNodeConfig
|
||||||
from tools.aisbench import run_aisbench_cases
|
from tools.aisbench import run_aisbench_cases
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
@@ -100,8 +100,10 @@ async def test_multi_node() -> None:
|
|||||||
disaggregated_prefill = config.disaggregated_prefill
|
disaggregated_prefill = config.disaggregated_prefill
|
||||||
server_port = config.server_port
|
server_port = config.server_port
|
||||||
proxy_port = config.proxy_port
|
proxy_port = config.proxy_port
|
||||||
server_host = config.cluster_ips[0]
|
server_host = config.node_info.ip
|
||||||
with config.launch_server_proxy(DISAGGREGATED_PREFILL_PROXY_SCRIPT):
|
proxy_script = config.envs.get("DISAGGREGATED_PREFILL_PROXY_SCRIPT", \
|
||||||
|
'examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py')
|
||||||
|
with config.launch_server_proxy(proxy_script):
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
model=local_model_path,
|
model=local_model_path,
|
||||||
vllm_serve_args=config.server_cmd,
|
vllm_serve_args=config.server_cmd,
|
||||||
|
|||||||
Reference in New Issue
Block a user