[1/N] Refactor nightly test structure (#5479)
### What this PR does / why we need it?
This patch is a series of refactoring actions, including clarifying the
directory structure of nightly tests, refactoring the config retrieval
logic, and optimizing the workflow, etc. This is the first step:
refactoring the directory structure of nightly to make it more readable
and logical.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
6
.github/workflows/nightly_test_a2.yaml
vendored
6
.github/workflows/nightly_test_a2.yaml
vendored
@@ -51,13 +51,13 @@ jobs:
|
|||||||
test_config:
|
test_config:
|
||||||
- name: qwen3-32b
|
- name: qwen3-32b
|
||||||
os: linux-aarch64-a2-4
|
os: linux-aarch64-a2-4
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_32b.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_32b.py
|
||||||
- name: qwen3-32b-in8-a2
|
- name: qwen3-32b-in8-a2
|
||||||
os: linux-aarch64-a2-4
|
os: linux-aarch64-a2-4
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_32b_int8.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py
|
||||||
- name: test_custom_op
|
- name: test_custom_op
|
||||||
os: linux-aarch64-a2-1
|
os: linux-aarch64-a2-1
|
||||||
tests: tests/e2e/nightly/ops
|
tests: tests/e2e/nightly/single_node/ops/singlecard_ops
|
||||||
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
||||||
with:
|
with:
|
||||||
vllm: v0.13.0
|
vllm: v0.13.0
|
||||||
|
|||||||
42
.github/workflows/nightly_test_a3.yaml
vendored
42
.github/workflows/nightly_test_a3.yaml
vendored
@@ -56,15 +56,15 @@ jobs:
|
|||||||
- name: multi-node-qwen3-dp
|
- name: multi-node-qwen3-dp
|
||||||
config_file_path: Qwen3-235B-A22B.yaml
|
config_file_path: Qwen3-235B-A22B.yaml
|
||||||
size: 2
|
size: 2
|
||||||
- name: multi-node-dpsk-4node-pd
|
# - name: multi-node-dpsk-4node-pd
|
||||||
config_file_path: DeepSeek-R1-W8A8.yaml
|
# config_file_path: DeepSeek-R1-W8A8.yaml
|
||||||
size: 4
|
# size: 4
|
||||||
- name: multi-node-qwenw8a8-2node
|
- name: multi-node-qwenw8a8-2node
|
||||||
config_file_path: Qwen3-235B-W8A8.yaml
|
config_file_path: Qwen3-235B-W8A8.yaml
|
||||||
size: 2
|
size: 2
|
||||||
- name: multi-node-deepseek-r1-w8a8-eplb
|
# - name: multi-node-deepseek-r1-w8a8-eplb
|
||||||
config_file_path: DeepSeek-R1-W8A8-EPLB.yaml
|
# config_file_path: DeepSeek-R1-W8A8-EPLB.yaml
|
||||||
size: 4
|
# size: 4
|
||||||
- name: multi-node-qwenw8a8-2node-eplb
|
- name: multi-node-qwenw8a8-2node-eplb
|
||||||
config_file_path: Qwen3-235B-W8A8-EPLB.yaml
|
config_file_path: Qwen3-235B-W8A8-EPLB.yaml
|
||||||
size: 2
|
size: 2
|
||||||
@@ -89,47 +89,47 @@ jobs:
|
|||||||
test_config:
|
test_config:
|
||||||
- name: qwen3-32b-in8-a3
|
- name: qwen3-32b-in8-a3
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_32b_int8.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py
|
||||||
- name: qwen3-32b-int8-a3-feature-stack3
|
- name: qwen3-32b-int8-a3-feature-stack3
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/features/test_qwen3_32b_int8_a3_feature_stack3.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_32b_int8_a3_feature_stack3.py
|
||||||
- name: qwen3-235b-a22b-w8a8-eplb
|
- name: qwen3-235b-a22b-w8a8-eplb
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_235b_a22b_w8a8_eplb.py
|
||||||
- name: deepseek-r1-w8a8-eplb
|
- name: deepseek-r1-w8a8-eplb
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8_eplb.py
|
tests: tests/e2e/nightly/single_node/models/test_deepseek_r1_0528_w8a8_eplb.py
|
||||||
- name: deepseek-r1-w8a8-mtpx
|
- name: deepseek-r1-w8a8-mtpx
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py
|
tests: tests/e2e/nightly/single_node/models/test_mtpx_deepseek_r1_0528_w8a8.py
|
||||||
- name: qwen2-5-vl-7b
|
- name: qwen2-5-vl-7b
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/models/test_qwen2_5_vl_7b.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b.py
|
||||||
- name: qwen2-5-vl-32b
|
- name: qwen2-5-vl-32b
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/models/test_qwen2_5_vl_32b.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_32b.py
|
||||||
- name: qwen3-32b-int8-prefix-cache
|
- name: qwen3-32b-int8-prefix-cache
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/features/test_prefix_cache_qwen3_32b_int8.py
|
tests: tests/e2e/nightly/single_node/models/test_prefix_cache_qwen3_32b_int8.py
|
||||||
- name: deepseek-r1-0528-w8a8
|
- name: deepseek-r1-0528-w8a8
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py
|
tests: tests/e2e/nightly/single_node/models/test_deepseek_r1_0528_w8a8.py
|
||||||
- name: deepseek-r1-0528-w8a8-prefix-cache
|
- name: deepseek-r1-0528-w8a8-prefix-cache
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py
|
tests: tests/e2e/nightly/single_node/models/test_prefix_cache_deepseek_r1_0528_w8a8.py
|
||||||
- name: qwq-32b-a3
|
- name: qwq-32b-a3
|
||||||
os: linux-aarch64-a3-4
|
os: linux-aarch64-a3-4
|
||||||
tests: tests/e2e/nightly/models/test_qwq_32b.py
|
tests: tests/e2e/nightly/single_node/models/test_qwq_32b.py
|
||||||
- name: qwen3-30b-w8a8
|
- name: qwen3-30b-w8a8
|
||||||
os: linux-aarch64-a3-2
|
os: linux-aarch64-a3-2
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_30b_w8a8.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_30b_w8a8.py
|
||||||
- name: qwen3-235b-w8a8
|
- name: qwen3-235b-w8a8
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/models/test_qwen3_235b_w8a8.py
|
tests: tests/e2e/nightly/single_node/models/test_qwen3_235b_w8a8.py
|
||||||
# TODO: Replace deepseek3.2-exp with deepseek3.2 after nightly tests pass
|
# TODO: Replace deepseek3.2-exp with deepseek3.2 after nightly tests pass
|
||||||
# - name: deepseek3_2-exp-w8a8
|
# - name: deepseek3_2-exp-w8a8
|
||||||
# os: linux-aarch64-a3-16
|
# os: linux-aarch64-a3-16
|
||||||
# tests: tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py
|
# tests: tests/e2e/nightly/single_node/models/test_deepseek_v3_2_exp_w8a8.py
|
||||||
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
||||||
with:
|
with:
|
||||||
vllm: v0.13.0
|
vllm: v0.13.0
|
||||||
@@ -148,7 +148,7 @@ jobs:
|
|||||||
test_config:
|
test_config:
|
||||||
- name: custom-op-dispatch_gmm_combine_decode
|
- name: custom-op-dispatch_gmm_combine_decode
|
||||||
os: linux-aarch64-a3-16
|
os: linux-aarch64-a3-16
|
||||||
tests: tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
|
tests: tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py
|
||||||
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
||||||
with:
|
with:
|
||||||
runner: ${{ matrix.test_config.os }}
|
runner: ${{ matrix.test_config.os }}
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from typing import Any, Optional, Tuple, TypeVar, Union
|
from typing import Any, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import httpx
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
@@ -52,7 +51,8 @@ from vllm.utils.network_utils import get_open_port
|
|||||||
|
|
||||||
from tests.e2e.model_utils import (TokensTextLogprobs,
|
from tests.e2e.model_utils import (TokensTextLogprobs,
|
||||||
TokensTextLogprobsPromptLogprobs)
|
TokensTextLogprobsPromptLogprobs)
|
||||||
from tests.e2e.nightly.multi_node.config.multi_node_config import NodeInfo
|
from tests.e2e.nightly.multi_node.scripts.multi_node_config import (
|
||||||
|
DisaggregatedPrefillCfg, NodeInfo)
|
||||||
from vllm_ascend.ascend_config import clear_ascend_config
|
from vllm_ascend.ascend_config import clear_ascend_config
|
||||||
# TODO: remove this part after the patch merged into vllm, if
|
# TODO: remove this part after the patch merged into vllm, if
|
||||||
# we not explicitly patch here, some of them might be effectiveless
|
# we not explicitly patch here, some of them might be effectiveless
|
||||||
@@ -104,6 +104,7 @@ class RemoteOpenAIServer:
|
|||||||
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
||||||
if env_dict is not None:
|
if env_dict is not None:
|
||||||
env.update(env_dict)
|
env.update(env_dict)
|
||||||
|
logger.info(f"Starting server with command: {' '.join(server_cmd)}")
|
||||||
self.proc: subprocess.Popen = subprocess.Popen(
|
self.proc: subprocess.Popen = subprocess.Popen(
|
||||||
server_cmd,
|
server_cmd,
|
||||||
env=env,
|
env=env,
|
||||||
@@ -111,20 +112,21 @@ class RemoteOpenAIServer:
|
|||||||
stderr=sys.stderr,
|
stderr=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
model: str,
|
self,
|
||||||
vllm_serve_args: Union[list[str], str],
|
model: str,
|
||||||
*,
|
vllm_serve_args: Union[list[str], str],
|
||||||
server_host: str = '0.0.0.0',
|
*,
|
||||||
server_port: int = 8080,
|
server_host: str = '0.0.0.0',
|
||||||
env_dict: Optional[dict[str, str]] = None,
|
server_port: int = 8080,
|
||||||
seed: Optional[int] = None,
|
env_dict: Optional[dict[str, str]] = None,
|
||||||
auto_port: bool = True,
|
seed: Optional[int] = None,
|
||||||
nodes_info: Optional[list[NodeInfo]] = None,
|
auto_port: bool = True,
|
||||||
disaggregated_prefill: Optional[dict] = None,
|
nodes_info: Optional[list[NodeInfo]] = None,
|
||||||
proxy_port: Optional[int] = None,
|
disaggregated_prefill: Optional[DisaggregatedPrefillCfg] = None,
|
||||||
max_wait_seconds: Optional[float] = None,
|
proxy_port: Optional[int] = None,
|
||||||
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
|
max_wait_seconds: Optional[float] = None,
|
||||||
|
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
|
||||||
if isinstance(vllm_serve_args, str):
|
if isinstance(vllm_serve_args, str):
|
||||||
vllm_serve_args = shlex.split(vllm_serve_args)
|
vllm_serve_args = shlex.split(vllm_serve_args)
|
||||||
else:
|
else:
|
||||||
@@ -187,6 +189,7 @@ class RemoteOpenAIServer:
|
|||||||
This is for headless mode, where the api server
|
This is for headless mode, where the api server
|
||||||
process only exists in the leader node.
|
process only exists in the leader node.
|
||||||
"""
|
"""
|
||||||
|
logger.info("Hanging until server process terminates...")
|
||||||
client = requests
|
client = requests
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@@ -198,8 +201,6 @@ class RemoteOpenAIServer:
|
|||||||
except Exception:
|
except Exception:
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
if isinstance(client, httpx.Client):
|
|
||||||
client.close()
|
|
||||||
self._terminate_server()
|
self._terminate_server()
|
||||||
|
|
||||||
def _wait_for_server_pd(self, timeout: float):
|
def _wait_for_server_pd(self, timeout: float):
|
||||||
@@ -210,8 +211,7 @@ class RemoteOpenAIServer:
|
|||||||
def url_health(ip: str, port: int) -> str:
|
def url_health(ip: str, port: int) -> str:
|
||||||
return f"http://{ip}:{port}/health"
|
return f"http://{ip}:{port}/health"
|
||||||
|
|
||||||
targets = [(node_info.ip,
|
targets = [(node_info.ip, url_health(node_info.ip, self.port))
|
||||||
url_health(node_info.ip, node_info.server_port))
|
|
||||||
for node_info in self.nodes_info if not node_info.headless]
|
for node_info in self.nodes_info if not node_info.headless]
|
||||||
|
|
||||||
# Wait for proxy ready
|
# Wait for proxy ready
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ env_common:
|
|||||||
disaggregated_prefill:
|
disaggregated_prefill:
|
||||||
enabled: true
|
enabled: true
|
||||||
prefiller_host_index: [0, 1]
|
prefiller_host_index: [0, 1]
|
||||||
decoder_host_index: [2]
|
decoder_host_index: [2, 3]
|
||||||
|
|
||||||
deployment:
|
deployment:
|
||||||
-
|
-
|
||||||
@@ -16,7 +16,7 @@ env_common:
|
|||||||
disaggregated_prefill:
|
disaggregated_prefill:
|
||||||
enabled: true
|
enabled: true
|
||||||
prefiller_host_index: [0, 1]
|
prefiller_host_index: [0, 1]
|
||||||
decoder_host_index: [2]
|
decoder_host_index: [2, 3]
|
||||||
|
|
||||||
deployment:
|
deployment:
|
||||||
-
|
-
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import regex as re
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from tests.e2e.nightly.multi_node.config.utils import (get_all_ipv4,
|
|
||||||
get_avaliable_port,
|
|
||||||
get_cluster_ips,
|
|
||||||
get_net_interface,
|
|
||||||
setup_logger)
|
|
||||||
|
|
||||||
setup_logger()
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
DISAGGEGATED_PREFILL_PORT = 5333
|
|
||||||
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/models/"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeInfo:
|
|
||||||
index: int
|
|
||||||
ip: str
|
|
||||||
server_cmd: str
|
|
||||||
headless: bool
|
|
||||||
server_port: int
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return (f"NodeInfo:\n"
|
|
||||||
f" index={self.index}\n"
|
|
||||||
f" ip={self.ip}\n"
|
|
||||||
f" headless={self.headless}\n"
|
|
||||||
f" server_port={self.server_port}")
|
|
||||||
|
|
||||||
|
|
||||||
class MultiNodeConfig:
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
model: str,
|
|
||||||
test_name: str,
|
|
||||||
nodes_info: list[NodeInfo],
|
|
||||||
npu_per_node: int = 16,
|
|
||||||
server_port: int = 8080,
|
|
||||||
disaggregated_prefill: Optional[dict] = None,
|
|
||||||
envs: Optional[dict] = None,
|
|
||||||
perf_cmd: Optional[str] = None,
|
|
||||||
acc_cmd: Optional[str] = None):
|
|
||||||
self.test_name = test_name
|
|
||||||
self.model = model
|
|
||||||
self.nodes_info = nodes_info
|
|
||||||
# We assume the first index of nodes as the master
|
|
||||||
# NOTE: this may be different in the scenarios like disaggregated prefill
|
|
||||||
# There may be multi groups of nodes, and the master of each group may be different
|
|
||||||
self.master_ip = self.nodes_info[0].ip
|
|
||||||
self.num_nodes = len(self.nodes_info)
|
|
||||||
self.npu_per_node = npu_per_node
|
|
||||||
self.server_port = server_port
|
|
||||||
self.envs = envs if envs is not None else {}
|
|
||||||
self.proxy_port = get_avaliable_port()
|
|
||||||
self.perf_cmd = perf_cmd
|
|
||||||
self.acc_cmd = acc_cmd
|
|
||||||
|
|
||||||
self.disaggregated_prefill = disaggregated_prefill
|
|
||||||
self._init_disaggregated_prefill()
|
|
||||||
|
|
||||||
self._init_dist_env()
|
|
||||||
self.server_cmd = self._expand_env_vars(self.node_info.server_cmd,
|
|
||||||
self.envs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cur_ip(self):
|
|
||||||
return self.nodes_info[self.cur_index].ip
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nic_name(self):
|
|
||||||
return get_net_interface(self.cur_ip)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def node_info(self):
|
|
||||||
return self.nodes_info[self.cur_index]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cur_index(self):
|
|
||||||
# 1. Try to read worker index from K8s environment variable
|
|
||||||
worker_index = os.environ.get("LWS_WORKER_INDEX")
|
|
||||||
if worker_index:
|
|
||||||
return int(worker_index)
|
|
||||||
|
|
||||||
# 2. Fallback: match local IP against cluster IP list
|
|
||||||
cluster_ips = [node.ip for node in self.nodes_info]
|
|
||||||
cluster_ip_set = set(cluster_ips)
|
|
||||||
|
|
||||||
cur_ips = get_all_ipv4()
|
|
||||||
|
|
||||||
for ip in cur_ips:
|
|
||||||
if ip in cluster_ip_set:
|
|
||||||
return cluster_ips.index(ip)
|
|
||||||
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not determine current node index: no matching IP.\n"
|
|
||||||
f"Local machine IPs: {cur_ips}\n"
|
|
||||||
f"Cluster IPs: {cluster_ips}\n"
|
|
||||||
"Please check your config file or network settings.")
|
|
||||||
|
|
||||||
def _init_disaggregated_prefill(self):
|
|
||||||
if self.disaggregated_prefill:
|
|
||||||
decode_host_index = self.disaggregated_prefill.get(
|
|
||||||
"decoder_host_index")
|
|
||||||
if not decode_host_index:
|
|
||||||
raise RuntimeError("got empty decode_host_index")
|
|
||||||
self.decode_start_index: int = decode_host_index[0]
|
|
||||||
self.num_prefillers = self.decode_start_index
|
|
||||||
self.num_decoders = self.num_nodes - self.num_prefillers
|
|
||||||
|
|
||||||
def _init_dist_env(self):
|
|
||||||
self.envs["HCCL_IF_IP"] = self.cur_ip
|
|
||||||
self.envs["GLOO_SOCKET_IFNAME"] = self.nic_name
|
|
||||||
self.envs["TP_SOCKET_IFNAME"] = self.nic_name
|
|
||||||
self.envs["HCCL_SOCKET_IFNAME"] = self.nic_name
|
|
||||||
self.envs["LOCAL_IP"] = self.cur_ip
|
|
||||||
self.envs["NIC_NAME"] = self.nic_name
|
|
||||||
|
|
||||||
master_ip = self.master_ip
|
|
||||||
if self.disaggregated_prefill:
|
|
||||||
if self.cur_index < self.decode_start_index:
|
|
||||||
# For prefiller nodes, use the default master ip(index==0) as DP master
|
|
||||||
master_ip = self.master_ip
|
|
||||||
else:
|
|
||||||
# For decoder nodes, use the first decoder node as DP master
|
|
||||||
master_ip = self.nodes_info[self.decode_start_index].ip
|
|
||||||
|
|
||||||
self.envs["MASTER_IP"] = master_ip
|
|
||||||
ascend_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
|
|
||||||
self.envs[
|
|
||||||
"LD_LIBRARY_PATH"] = f"{ascend_path}:{self.envs.get('LD_LIBRARY_PATH', os.environ.get('LD_LIBRARY_PATH', ''))}"
|
|
||||||
|
|
||||||
# keep the envs keys and values as strings
|
|
||||||
str_envs = {k: str(v) for k, v in self.envs.items()}
|
|
||||||
self.envs.clear()
|
|
||||||
self.envs.update(str_envs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _expand_env_vars(cmd: str, env: dict) -> str:
|
|
||||||
"""Expand environment variables in the command string."""
|
|
||||||
cmd = str(cmd)
|
|
||||||
pattern = re.compile(r"\$(\w+)|\$\{(\w+)\}")
|
|
||||||
|
|
||||||
def replace_var(match):
|
|
||||||
var_name = match.group(1) or match.group(2)
|
|
||||||
return str(env.get(var_name, match.group(0)))
|
|
||||||
|
|
||||||
return pattern.sub(replace_var, cmd)
|
|
||||||
|
|
||||||
class _ProxyContext:
|
|
||||||
|
|
||||||
def __init__(self, outer, proxy_script):
|
|
||||||
self.outer = outer
|
|
||||||
self.proxy_script = proxy_script
|
|
||||||
self.process = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
o = self.outer
|
|
||||||
if not o.disaggregated_prefill or not o.is_master:
|
|
||||||
logger.info(
|
|
||||||
"Disaggregated prefill not enabled or not master node, skipping proxy launch."
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
prefiller_indices = o.disaggregated_prefill["prefiller_host_index"]
|
|
||||||
decoder_indices = o.disaggregated_prefill["decoder_host_index"]
|
|
||||||
|
|
||||||
common_indices = set(prefiller_indices) & set(decoder_indices)
|
|
||||||
assert not common_indices, f"Common indices found: {common_indices}"
|
|
||||||
assert o.proxy_port is not None, "proxy_port must be set"
|
|
||||||
|
|
||||||
cluster_ips = [node.ip for node in o.nodes_info]
|
|
||||||
prefiller_ips = [cluster_ips[i] for i in prefiller_indices]
|
|
||||||
decoder_ips = [cluster_ips[i] for i in decoder_indices]
|
|
||||||
prefiller_ports_list = [str(o.server_port)] * len(prefiller_ips)
|
|
||||||
decoder_ports_list = [str(o.server_port)] * len(decoder_ips)
|
|
||||||
|
|
||||||
proxy_cmd = [
|
|
||||||
"python",
|
|
||||||
self.proxy_script,
|
|
||||||
"--host",
|
|
||||||
o.cur_ip,
|
|
||||||
"--port",
|
|
||||||
str(o.proxy_port),
|
|
||||||
"--prefiller-hosts",
|
|
||||||
*prefiller_ips,
|
|
||||||
"--prefiller-ports",
|
|
||||||
*prefiller_ports_list,
|
|
||||||
"--decoder-hosts",
|
|
||||||
*decoder_ips,
|
|
||||||
"--decoder-ports",
|
|
||||||
*decoder_ports_list,
|
|
||||||
]
|
|
||||||
|
|
||||||
env = os.environ.copy()
|
|
||||||
env.update(o.envs)
|
|
||||||
logger.info(f"Launching proxy: {' '.join(proxy_cmd)}")
|
|
||||||
|
|
||||||
self.process = subprocess.Popen(proxy_cmd, env=env)
|
|
||||||
o.proxy_process = self.process
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
if self.process:
|
|
||||||
logger.info("Terminating proxy server process...")
|
|
||||||
try:
|
|
||||||
self.process.terminate()
|
|
||||||
self.process.wait(timeout=5)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.warning(
|
|
||||||
"Proxy process did not terminate, killing it...")
|
|
||||||
self.process.kill()
|
|
||||||
logger.info("Proxy server process terminated.")
|
|
||||||
|
|
||||||
def launch_server_proxy(self, proxy_script: str):
|
|
||||||
"""Return a context manager that launches the proxy server if disaggregated prefill is enabled."""
|
|
||||||
return self._ProxyContext(self, proxy_script)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_yaml(cls, yaml_path: Optional[str] = None):
|
|
||||||
if not yaml_path:
|
|
||||||
yaml_path = os.getenv("CONFIG_YAML_PATH", "DeepSeek-V3.yaml")
|
|
||||||
yaml_path = os.path.join(CONFIG_BASE_PATH, yaml_path)
|
|
||||||
with open(yaml_path, 'r') as file:
|
|
||||||
config_data = yaml.safe_load(file)
|
|
||||||
test_name = config_data.get("test_name", "default_test")
|
|
||||||
model = config_data.get("model", "default_model")
|
|
||||||
envs = config_data.get("env_common", {})
|
|
||||||
num_nodes = config_data.get("num_nodes", 2)
|
|
||||||
npu_per_node = config_data.get("npu_per_node", 16)
|
|
||||||
disaggregated_prefill = config_data.get("disaggregated_prefill")
|
|
||||||
# If disaggregated_prefill is set, override server_port to an available port for proxy running
|
|
||||||
server_port = config_data.get("server_port", 8080)
|
|
||||||
|
|
||||||
deployments = config_data.get("deployment", [])
|
|
||||||
assert len(deployments) == num_nodes, \
|
|
||||||
f"Number of deployments ({len(deployments)}) must match num_nodes ({num_nodes})"
|
|
||||||
cluster_ips = config_data.get("cluster_hosts", None)
|
|
||||||
if cluster_ips:
|
|
||||||
assert len(cluster_ips) == num_nodes, \
|
|
||||||
"Must provide cluster_ips for all nodes if it is explicitly specified."
|
|
||||||
else:
|
|
||||||
logger.info("Resolving cluster IPs via DNS...")
|
|
||||||
cluster_ips = get_cluster_ips(num_nodes)
|
|
||||||
nodes_info = []
|
|
||||||
|
|
||||||
for index, deployment in enumerate(deployments):
|
|
||||||
# after assert len(deployments) == num_nodes, we can assume that this will must have a match
|
|
||||||
server_cmd = deployment.get("server_cmd", "")
|
|
||||||
headless = "--headless" in server_cmd
|
|
||||||
nodes_info.append(
|
|
||||||
NodeInfo(ip=cluster_ips[index],
|
|
||||||
index=index,
|
|
||||||
headless=headless,
|
|
||||||
server_port=server_port,
|
|
||||||
server_cmd=server_cmd))
|
|
||||||
|
|
||||||
benchmarks = config_data.get("benchmarks") or {}
|
|
||||||
assert benchmarks is not None, "benchmarks must be provided"
|
|
||||||
perf_cmd = benchmarks.get("perf")
|
|
||||||
acc_cmd = benchmarks.get("acc")
|
|
||||||
|
|
||||||
return cls(model=model,
|
|
||||||
test_name=test_name,
|
|
||||||
npu_per_node=npu_per_node,
|
|
||||||
envs=envs,
|
|
||||||
server_port=server_port,
|
|
||||||
disaggregated_prefill=disaggregated_prefill,
|
|
||||||
nodes_info=nodes_info,
|
|
||||||
perf_cmd=perf_cmd,
|
|
||||||
acc_cmd=acc_cmd)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def world_size(self):
|
|
||||||
return self.num_nodes * self.npu_per_node
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_master(self):
|
|
||||||
return self.cur_index == 0
|
|
||||||
349
tests/e2e/nightly/multi_node/scripts/multi_node_config.py
Normal file
349
tests/e2e/nightly/multi_node/scripts/multi_node_config.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# isort: off
|
||||||
|
from tests.e2e.nightly.multi_node.scripts.utils import (
|
||||||
|
CONFIG_BASE_PATH, DEFAULT_SERVER_PORT, get_all_ipv4, get_cluster_ips,
|
||||||
|
get_net_interface, setup_logger, get_avaliable_port)
|
||||||
|
# isort: on
|
||||||
|
setup_logger()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NodeInfo:
|
||||||
|
index: int
|
||||||
|
ip: str
|
||||||
|
server_cmd: str
|
||||||
|
envs: dict | None = None
|
||||||
|
headless: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.ip:
|
||||||
|
raise ValueError("NodeInfo.ip must not be empty")
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return ("NodeInfo(\n"
|
||||||
|
f" index={self.index},\n"
|
||||||
|
f" ip={self.ip},\n"
|
||||||
|
f" headless={self.headless},\n"
|
||||||
|
")")
|
||||||
|
|
||||||
|
|
||||||
|
class DisaggregatedPrefillCfg:
|
||||||
|
|
||||||
|
def __init__(self, raw_cfg: dict, num_nodes: int):
|
||||||
|
self.prefiller_indices: list[int] = raw_cfg.get(
|
||||||
|
"prefiller_host_index", [])
|
||||||
|
self.decoder_indices: list[int] = raw_cfg.get("decoder_host_index", [])
|
||||||
|
|
||||||
|
if not self.decoder_indices:
|
||||||
|
raise RuntimeError("decoder_host_index must be provided")
|
||||||
|
|
||||||
|
self._validate(num_nodes)
|
||||||
|
|
||||||
|
self.decode_start_index = self.decoder_indices[0]
|
||||||
|
self.num_prefillers = len(self.prefiller_indices)
|
||||||
|
self.num_decoders = len(self.decoder_indices)
|
||||||
|
|
||||||
|
def _validate(self, num_nodes: int):
|
||||||
|
overlap = set(self.prefiller_indices) & set(self.decoder_indices)
|
||||||
|
if overlap:
|
||||||
|
raise AssertionError(f"Prefiller and decoder overlap: {overlap}")
|
||||||
|
|
||||||
|
all_indices = self.prefiller_indices + self.decoder_indices
|
||||||
|
if any(i >= num_nodes for i in all_indices):
|
||||||
|
raise ValueError("Disaggregated prefill index out of range")
|
||||||
|
|
||||||
|
def is_prefiller(self, index: int) -> bool:
|
||||||
|
return index in self.prefiller_indices
|
||||||
|
|
||||||
|
def is_decoder(self, index: int) -> bool:
|
||||||
|
return index in self.decoder_indices
|
||||||
|
|
||||||
|
def master_ip_for_node(self, index: int, nodes: list[NodeInfo]) -> str:
|
||||||
|
if self.is_prefiller(index):
|
||||||
|
return nodes[0].ip
|
||||||
|
return nodes[self.decode_start_index].ip
|
||||||
|
|
||||||
|
|
||||||
|
class DistEnvBuilder:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cur_node: NodeInfo,
|
||||||
|
master_ip: str,
|
||||||
|
common_envs: dict,
|
||||||
|
):
|
||||||
|
self.cur_ip = cur_node.ip
|
||||||
|
self.nic_name = get_net_interface(self.cur_ip)
|
||||||
|
self.master_ip = master_ip
|
||||||
|
|
||||||
|
# envs
|
||||||
|
common_envs = common_envs
|
||||||
|
current_envs = cur_node.envs or {}
|
||||||
|
# Node-specific envs override common envs
|
||||||
|
self.base_envs = {**common_envs, **current_envs}
|
||||||
|
|
||||||
|
def build(self) -> dict:
|
||||||
|
envs = dict(self.base_envs)
|
||||||
|
|
||||||
|
envs.update({
|
||||||
|
"HCCL_IF_IP": self.cur_ip,
|
||||||
|
"HCCL_SOCKET_IFNAME": self.nic_name,
|
||||||
|
"GLOO_SOCKET_IFNAME": self.nic_name,
|
||||||
|
"TP_SOCKET_IFNAME": self.nic_name,
|
||||||
|
"LOCAL_IP": self.cur_ip,
|
||||||
|
"NIC_NAME": self.nic_name,
|
||||||
|
"MASTER_IP": self.master_ip,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {k: str(v) for k, v in envs.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyLauncher:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
nodes: list[NodeInfo],
|
||||||
|
envs: dict,
|
||||||
|
proxy_port: int,
|
||||||
|
cur_index: int,
|
||||||
|
disagg_cfg: DisaggregatedPrefillCfg | None = None,
|
||||||
|
):
|
||||||
|
self.nodes = nodes
|
||||||
|
self.cfg = disagg_cfg
|
||||||
|
self.server_port = envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||||
|
self.proxy_port = proxy_port
|
||||||
|
self.proxy_script = envs.get(
|
||||||
|
"DISAGGREGATED_PREFILL_PROXY_SCRIPT",
|
||||||
|
'examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py'
|
||||||
|
)
|
||||||
|
self.envs = envs
|
||||||
|
self.is_master = cur_index == 0
|
||||||
|
self.cur_ip = nodes[cur_index].ip
|
||||||
|
self.process: Optional[subprocess.Popen[bytes]] = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if not self.is_master or self.cfg is None:
|
||||||
|
logger.info("Not launching proxy on non-master node")
|
||||||
|
return self
|
||||||
|
prefiller_ips = [self.nodes[i].ip for i in self.cfg.prefiller_indices]
|
||||||
|
decoder_ips = [self.nodes[i].ip for i in self.cfg.decoder_indices]
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"python",
|
||||||
|
self.proxy_script,
|
||||||
|
"--host",
|
||||||
|
self.cur_ip,
|
||||||
|
"--port",
|
||||||
|
str(self.proxy_port),
|
||||||
|
"--prefiller-hosts",
|
||||||
|
*prefiller_ips,
|
||||||
|
"--prefiller-ports",
|
||||||
|
*[str(self.server_port)] * len(prefiller_ips),
|
||||||
|
"--decoder-hosts",
|
||||||
|
*decoder_ips,
|
||||||
|
"--decoder-ports",
|
||||||
|
*[str(self.server_port)] * len(decoder_ips),
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info("Launching proxy: %s", " ".join(cmd))
|
||||||
|
self.process = subprocess.Popen(cmd, env={**os.environ, **self.envs})
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
if not self.process:
|
||||||
|
return
|
||||||
|
logger.info("Stopping proxy server...")
|
||||||
|
self.process.terminate()
|
||||||
|
try:
|
||||||
|
self.process.wait(timeout=5)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
self.process.kill()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiNodeConfig:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
test_name: str,
|
||||||
|
nodes: list[NodeInfo],
|
||||||
|
npu_per_node: int,
|
||||||
|
envs: dict,
|
||||||
|
disaggregated_prefill: dict | None,
|
||||||
|
perf_cmd: str | None,
|
||||||
|
acc_cmd: str | None,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.test_name = test_name
|
||||||
|
self.nodes = nodes
|
||||||
|
self.npu_per_node = npu_per_node
|
||||||
|
self.perf_cmd = perf_cmd
|
||||||
|
self.acc_cmd = acc_cmd
|
||||||
|
|
||||||
|
self.cur_index = self._resolve_cur_index()
|
||||||
|
self.cur_node = self.nodes[self.cur_index]
|
||||||
|
|
||||||
|
self.disagg_cfg = (DisaggregatedPrefillCfg(disaggregated_prefill,
|
||||||
|
len(nodes))
|
||||||
|
if disaggregated_prefill else None)
|
||||||
|
|
||||||
|
master_ip = (self.disagg_cfg.master_ip_for_node(
|
||||||
|
self.cur_index, self.nodes)
|
||||||
|
if self.disagg_cfg else self.nodes[0].ip)
|
||||||
|
self.proxy_port = get_avaliable_port()
|
||||||
|
|
||||||
|
self.envs = DistEnvBuilder(
|
||||||
|
cur_node=self.cur_node,
|
||||||
|
master_ip=master_ip,
|
||||||
|
common_envs=envs,
|
||||||
|
).build()
|
||||||
|
logger.info("Node %d envs: %s", self.cur_index, self.envs)
|
||||||
|
|
||||||
|
self.server_cmd = self._expand_env(self.cur_node.server_cmd)
|
||||||
|
|
||||||
|
def _resolve_cur_index(self) -> int:
|
||||||
|
if (idx := os.environ.get("LWS_WORKER_INDEX")):
|
||||||
|
return int(idx)
|
||||||
|
|
||||||
|
local_ips = get_all_ipv4()
|
||||||
|
for i, node in enumerate(self.nodes):
|
||||||
|
if node.ip in local_ips:
|
||||||
|
return i
|
||||||
|
|
||||||
|
raise RuntimeError("Unable to determine current node index")
|
||||||
|
|
||||||
|
def _expand_env(self, cmd: str) -> str:
|
||||||
|
pattern = re.compile(r"\$(\w+)|\$\{(\w+)\}")
|
||||||
|
|
||||||
|
def repl(m):
|
||||||
|
key = m.group(1) or m.group(2)
|
||||||
|
return self.envs.get(key, m.group(0))
|
||||||
|
|
||||||
|
return pattern.sub(repl, cmd)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self) -> int:
|
||||||
|
return len(self.nodes) * self.npu_per_node
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_master(self) -> bool:
|
||||||
|
return self.cur_index == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_port(self) -> int:
|
||||||
|
return self.envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def master_ip(self) -> str:
|
||||||
|
return self.nodes[0].ip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def benchmark_endpoint(self) -> tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Endpoint used by benchmark clients.
|
||||||
|
"""
|
||||||
|
master_ip = self.nodes[0].ip
|
||||||
|
server_port = self.envs.get("SERVER_PORT", DEFAULT_SERVER_PORT)
|
||||||
|
if self.disagg_cfg:
|
||||||
|
return master_ip, self.proxy_port
|
||||||
|
return master_ip, server_port
|
||||||
|
|
||||||
|
|
||||||
|
class MultiNodeConfigLoader:
|
||||||
|
"""Load MultiNodeConfig from yaml file."""
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_NAME = "DeepSeek-V3.yaml"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, yaml_path: Optional[str] = None) -> MultiNodeConfig:
|
||||||
|
config = cls._load_yaml(yaml_path)
|
||||||
|
cls._validate_root(config)
|
||||||
|
|
||||||
|
nodes = cls._parse_nodes(config)
|
||||||
|
benchmarks = cls._parse_benchmarks(config)
|
||||||
|
|
||||||
|
return MultiNodeConfig(
|
||||||
|
model=config["model"],
|
||||||
|
test_name=config.get("test_name", "untitled_test"),
|
||||||
|
nodes=nodes,
|
||||||
|
npu_per_node=config.get("npu_per_node", 16),
|
||||||
|
envs=config.get("env_common", {}),
|
||||||
|
disaggregated_prefill=config.get("disaggregated_prefill"),
|
||||||
|
perf_cmd=benchmarks.get("perf"),
|
||||||
|
acc_cmd=benchmarks.get("acc"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_yaml(cls, yaml_path: Optional[str]) -> dict:
|
||||||
|
if not yaml_path:
|
||||||
|
yaml_path = os.getenv("CONFIG_YAML_PATH", cls.DEFAULT_CONFIG_NAME)
|
||||||
|
|
||||||
|
full_path = os.path.join(CONFIG_BASE_PATH, yaml_path)
|
||||||
|
logger.info("Loading config yaml: %s", full_path)
|
||||||
|
|
||||||
|
with open(full_path, "r") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_root(cfg: dict):
|
||||||
|
required = [
|
||||||
|
"model", "deployment", "num_nodes", "npu_per_node", "env_common",
|
||||||
|
"benchmarks"
|
||||||
|
]
|
||||||
|
missing = [k for k in required if k not in cfg]
|
||||||
|
if missing:
|
||||||
|
raise KeyError(f"Missing required config fields: {missing}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_nodes(cls, cfg: dict) -> list[NodeInfo]:
|
||||||
|
num_nodes = cfg["num_nodes"]
|
||||||
|
deployments = cfg["deployment"]
|
||||||
|
|
||||||
|
if len(deployments) != num_nodes:
|
||||||
|
raise AssertionError(
|
||||||
|
f"deployment size ({len(deployments)}) != num_nodes ({num_nodes})"
|
||||||
|
)
|
||||||
|
|
||||||
|
cluster_ips = cls._resolve_cluster_ips(cfg, num_nodes)
|
||||||
|
|
||||||
|
nodes: list[NodeInfo] = []
|
||||||
|
for idx, deploy in enumerate(deployments):
|
||||||
|
cmd = deploy.get("server_cmd", "")
|
||||||
|
envs = deploy.get("envs", {})
|
||||||
|
nodes.append(
|
||||||
|
NodeInfo(
|
||||||
|
index=idx,
|
||||||
|
ip=cluster_ips[idx],
|
||||||
|
server_cmd=cmd,
|
||||||
|
envs=envs,
|
||||||
|
headless="--headless" in cmd,
|
||||||
|
))
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_benchmarks(cfg: dict) -> dict:
|
||||||
|
benchmarks = cfg.get("benchmarks") or {}
|
||||||
|
return benchmarks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_cluster_ips(cfg: dict, num_nodes: int) -> list[str]:
|
||||||
|
if "cluster_hosts" in cfg and cfg["cluster_hosts"]:
|
||||||
|
ips = cfg["cluster_hosts"]
|
||||||
|
if len(ips) != num_nodes:
|
||||||
|
raise AssertionError("cluster_hosts size mismatch")
|
||||||
|
return ips
|
||||||
|
|
||||||
|
logger.info("Resolving cluster IPs via DNS...")
|
||||||
|
return get_cluster_ips(num_nodes)
|
||||||
@@ -9,11 +9,15 @@ RED="\033[0;31m"
|
|||||||
NC="\033[0m" # No Color
|
NC="\033[0m" # No Color
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
LOG_DIR="/root/.cache/tests/logs"
|
|
||||||
OVERWRITE_LOGS=true
|
|
||||||
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
|
||||||
|
# Home path for aisbench
|
||||||
export BENCHMARK_HOME=${WORKSPACE}/vllm-ascend/benchmark
|
export BENCHMARK_HOME=${WORKSPACE}/vllm-ascend/benchmark
|
||||||
|
|
||||||
|
# Logging configurations
|
||||||
export VLLM_LOGGING_LEVEL="INFO"
|
export VLLM_LOGGING_LEVEL="INFO"
|
||||||
|
# Reduce glog verbosity for mooncake
|
||||||
|
export GLOG_minloglevel=1
|
||||||
|
# Set transformers to offline mode to avoid downloading models during tests
|
||||||
export TRANSFORMERS_OFFLINE="1"
|
export TRANSFORMERS_OFFLINE="1"
|
||||||
|
|
||||||
# Function to print section headers
|
# Function to print section headers
|
||||||
@@ -131,7 +135,7 @@ kill_npu_processes() {
|
|||||||
run_tests_with_log() {
|
run_tests_with_log() {
|
||||||
set +e
|
set +e
|
||||||
kill_npu_processes
|
kill_npu_processes
|
||||||
pytest -sv --show-capture=no tests/e2e/nightly/multi_node/test_multi_node.py
|
pytest -sv --show-capture=no tests/e2e/nightly/multi_node/scripts/test_multi_node.py
|
||||||
ret=$?
|
ret=$?
|
||||||
set -e
|
set -e
|
||||||
if [ "$LWS_WORKER_INDEX" -eq 0 ]; then
|
if [ "$LWS_WORKER_INDEX" -eq 0 ]; then
|
||||||
|
|||||||
46
tests/e2e/nightly/multi_node/scripts/test_multi_node.py
Normal file
46
tests/e2e/nightly/multi_node/scripts/test_multi_node.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.e2e.conftest import RemoteOpenAIServer
|
||||||
|
from tests.e2e.nightly.multi_node.scripts.multi_node_config import (
|
||||||
|
MultiNodeConfigLoader, ProxyLauncher)
|
||||||
|
from tools.aisbench import run_aisbench_cases
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_node() -> None:
|
||||||
|
config = MultiNodeConfigLoader.from_yaml()
|
||||||
|
|
||||||
|
with ProxyLauncher(
|
||||||
|
nodes=config.nodes,
|
||||||
|
disagg_cfg=config.disagg_cfg,
|
||||||
|
envs=config.envs,
|
||||||
|
proxy_port=config.proxy_port,
|
||||||
|
cur_index=config.cur_index,
|
||||||
|
) as proxy:
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
model=config.model,
|
||||||
|
vllm_serve_args=config.server_cmd,
|
||||||
|
server_port=config.server_port,
|
||||||
|
server_host=config.master_ip,
|
||||||
|
env_dict=config.envs,
|
||||||
|
auto_port=False,
|
||||||
|
proxy_port=proxy.proxy_port,
|
||||||
|
disaggregated_prefill=config.disagg_cfg,
|
||||||
|
nodes_info=config.nodes,
|
||||||
|
max_wait_seconds=2800,
|
||||||
|
) as server:
|
||||||
|
|
||||||
|
host, port = config.benchmark_endpoint
|
||||||
|
|
||||||
|
if config.is_master:
|
||||||
|
run_aisbench_cases(
|
||||||
|
model=config.model,
|
||||||
|
port=port,
|
||||||
|
aisbench_cases=[config.acc_cmd, config.perf_cmd],
|
||||||
|
host_ip=host,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We should keep listening on the master node's server url determining when to exit.
|
||||||
|
server.hang_until_terminated(
|
||||||
|
f"http://{host}:{config.server_port}/health")
|
||||||
@@ -3,10 +3,14 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
|
DISAGGEGATED_PREFILL_PORT = 5333
|
||||||
|
CONFIG_BASE_PATH = "tests/e2e/nightly/multi_node/config/"
|
||||||
|
DEFAULT_SERVER_PORT = 8080
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def temp_env(env_dict):
|
def temp_env(env_dict):
|
||||||
@@ -41,13 +45,29 @@ def dns_resolver(retries: int = 240, base_delay: float = 0.5):
|
|||||||
return resolve
|
return resolve
|
||||||
|
|
||||||
|
|
||||||
def get_cluster_dns_list(word_size: int) -> list[str]:
|
def get_cluster_dns_list(world_size: int) -> List[str]:
|
||||||
|
if world_size < 1:
|
||||||
|
raise ValueError(f"world_size must be >= 1, got {world_size}")
|
||||||
|
|
||||||
leader_dns = os.getenv("LWS_LEADER_ADDRESS")
|
leader_dns = os.getenv("LWS_LEADER_ADDRESS")
|
||||||
if not leader_dns:
|
if not leader_dns:
|
||||||
raise RuntimeError("LWS_LEADER_ADDRESS is not set")
|
raise RuntimeError(
|
||||||
|
"environment variable LWS_LEADER_ADDRESS is not set")
|
||||||
|
|
||||||
workers = [f"vllm-0-{i}.vllm.vllm-project" for i in range(1, word_size)]
|
# Expected format:
|
||||||
return [leader_dns] + workers
|
# <leader-name>.<group-name>.<namespace>
|
||||||
|
parts = leader_dns.split(".")
|
||||||
|
if len(parts) < 3:
|
||||||
|
raise ValueError(f"invalid leader DNS format: {leader_dns}")
|
||||||
|
|
||||||
|
leader_name, group_name, namespace = parts[0], parts[1], parts[2]
|
||||||
|
|
||||||
|
worker_dns_list = [
|
||||||
|
f"{leader_name}-{idx}.{group_name}.{namespace}"
|
||||||
|
for idx in range(1, world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
return [leader_dns, *worker_dns_list]
|
||||||
|
|
||||||
|
|
||||||
def get_cluster_ips(word_size: int = 2) -> list[str]:
|
def get_cluster_ips(word_size: int = 2) -> list[str]:
|
||||||
@@ -92,7 +112,7 @@ def get_cur_ip(retries: int = 20, base_delay: float = 0.5):
|
|||||||
delay = min(delay * 1.5, 5)
|
delay = min(delay * 1.5, 5)
|
||||||
|
|
||||||
|
|
||||||
def get_net_interface(ip: Optional[str] = None) -> Optional[str]:
|
def get_net_interface(ip: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Returns specified IP's inetwork interface.
|
Returns specified IP's inetwork interface.
|
||||||
If no IP is provided, uses the first from hostname -I.
|
If no IP is provided, uses the first from hostname -I.
|
||||||
@@ -104,7 +124,7 @@ def get_net_interface(ip: Optional[str] = None) -> Optional[str]:
|
|||||||
for addr in addrs:
|
for addr in addrs:
|
||||||
if addr.family == socket.AF_INET and addr.address == ip:
|
if addr.family == socket.AF_INET and addr.address == ip:
|
||||||
return iface
|
return iface
|
||||||
return None
|
raise RuntimeError(f"No network interface found for IP {ip}")
|
||||||
|
|
||||||
|
|
||||||
def get_all_ipv4():
|
def get_all_ipv4():
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
from modelscope import snapshot_download # type: ignore
|
|
||||||
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
|
||||||
|
|
||||||
from tests.e2e.conftest import RemoteOpenAIServer
|
|
||||||
from tests.e2e.nightly.multi_node.config.multi_node_config import \
|
|
||||||
MultiNodeConfig
|
|
||||||
from tools.aisbench import run_aisbench_cases
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"San Francisco is a",
|
|
||||||
]
|
|
||||||
|
|
||||||
api_keyword_args = {
|
|
||||||
"max_tokens": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_model_path_with_retry(
|
|
||||||
model: str,
|
|
||||||
revision: str = "master",
|
|
||||||
max_retries: int = 5,
|
|
||||||
delay: int = 5,
|
|
||||||
) -> Optional[str]:
|
|
||||||
for attempt in range(1, max_retries + 1):
|
|
||||||
try:
|
|
||||||
local_model_path = snapshot_download(
|
|
||||||
model_id=model,
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
return local_model_path
|
|
||||||
|
|
||||||
except HTTPError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
except (ConnectionError, Timeout):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if attempt < max_retries:
|
|
||||||
time.sleep(delay)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_completions(url: str, model: str, prompts: Union[str, List[str]],
|
|
||||||
**api_kwargs: Any) -> List[str]:
|
|
||||||
"""
|
|
||||||
Asynchronously send HTTP requests to endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: Full endpoint URL, e.g. "http://localhost:1025/v1/completions"
|
|
||||||
model: Model name or local model path
|
|
||||||
prompts: A single prompt string or a list of prompts
|
|
||||||
**api_kwargs: Additional parameters (e.g., max_tokens, temperature)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: A list of generated texts corresponding to each prompt
|
|
||||||
"""
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
if isinstance(prompts, str):
|
|
||||||
prompts = [prompts]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
|
||||||
for prompt in prompts:
|
|
||||||
payload = {"model": model, "prompt": prompt, **api_kwargs}
|
|
||||||
|
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Request failed with status {response.status_code}: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resp_json = response.json()
|
|
||||||
choices = resp_json.get("choices", [])
|
|
||||||
if not choices or not choices[0].get("text"):
|
|
||||||
raise ValueError("Empty response from server")
|
|
||||||
|
|
||||||
results.append(choices[0]["text"])
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multi_node() -> None:
|
|
||||||
config = MultiNodeConfig.from_yaml()
|
|
||||||
# To avoid modelscope 400 HttpError, we should download the model with retry
|
|
||||||
local_model_path = get_local_model_path_with_retry(config.model)
|
|
||||||
config.server_cmd = config.server_cmd.replace(config.model,
|
|
||||||
local_model_path)
|
|
||||||
assert local_model_path is not None, "can not find any local weight for test"
|
|
||||||
env_dict = config.envs
|
|
||||||
perf_cmd = config.perf_cmd
|
|
||||||
acc_cmd = config.acc_cmd
|
|
||||||
nodes_info = config.nodes_info
|
|
||||||
disaggregated_prefill = config.disaggregated_prefill
|
|
||||||
server_port = config.server_port
|
|
||||||
proxy_port = config.proxy_port
|
|
||||||
server_host = config.master_ip
|
|
||||||
proxy_script = config.envs.get("DISAGGREGATED_PREFILL_PROXY_SCRIPT", \
|
|
||||||
'examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py')
|
|
||||||
with config.launch_server_proxy(proxy_script):
|
|
||||||
with RemoteOpenAIServer(
|
|
||||||
model=local_model_path,
|
|
||||||
vllm_serve_args=config.server_cmd,
|
|
||||||
server_port=server_port,
|
|
||||||
server_host=server_host,
|
|
||||||
env_dict=env_dict,
|
|
||||||
auto_port=False,
|
|
||||||
proxy_port=proxy_port,
|
|
||||||
disaggregated_prefill=disaggregated_prefill,
|
|
||||||
nodes_info=nodes_info,
|
|
||||||
max_wait_seconds=2800,
|
|
||||||
) as remote_server:
|
|
||||||
if config.is_master:
|
|
||||||
port = proxy_port if disaggregated_prefill else server_port
|
|
||||||
# aisbench test
|
|
||||||
aisbench_cases = [acc_cmd, perf_cmd]
|
|
||||||
run_aisbench_cases(local_model_path,
|
|
||||||
port,
|
|
||||||
aisbench_cases,
|
|
||||||
host_ip=config.master_ip)
|
|
||||||
else:
|
|
||||||
# for the nodes except master, should hang until the task complete
|
|
||||||
master_url = f"http://{config.master_ip}:{server_port}/health"
|
|
||||||
remote_server.hang_until_terminated(master_url)
|
|
||||||
Reference in New Issue
Block a user