[CI] improve disaggregation CI. (#11264)
Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -128,6 +128,10 @@ class Envs:
|
||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
|
||||
|
||||
# Test: pd-disaggregation
|
||||
SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
|
||||
SGLANG_TEST_PD_DISAGG_DEVICES = EnvStr(None)
|
||||
|
||||
# Model Parallel
|
||||
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_with_error_check,
|
||||
)
|
||||
|
||||
@@ -27,6 +31,24 @@ class TestDisaggregationBase(CustomTestCase):
|
||||
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
|
||||
cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
|
||||
|
||||
# config transfer backend and rdma devices
|
||||
if is_in_ci():
|
||||
cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"]
|
||||
cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()]
|
||||
else:
|
||||
cls.transfer_backend = [
|
||||
"--disaggregation-transfer-backend",
|
||||
envs.SGLANG_TEST_PD_DISAGG_BACKEND.get(),
|
||||
]
|
||||
cls.rdma_devices = [
|
||||
"--disaggregation-ib-device",
|
||||
envs.SGLANG_TEST_PD_DISAGG_DEVICES.get(),
|
||||
]
|
||||
if cls.rdma_devices[1] is None:
|
||||
cls.rdma_devices = []
|
||||
msg = "No RDMA devices specified for disaggregation test, using default settings."
|
||||
warnings.warn(msg)
|
||||
|
||||
@classmethod
|
||||
def launch_lb(cls):
|
||||
lb_command = [
|
||||
@@ -75,3 +97,44 @@ class TestDisaggregationBase(CustomTestCase):
|
||||
|
||||
# wait for 5 seconds
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def get_rdma_devices_args():
    """Derive the RDMA NIC argument string from CUDA_VISIBLE_DEVICES.

    Maps each visible GPU index to an ``mlx5_roce<N>`` device within its
    4-GPU group (group = ``min(indices) // 4 * 4`` .. +3). Falls back to
    ``"mlx5_roce0,mlx5_roce4"`` whenever the environment variable is
    missing, malformed, empty, lists more than 4 GPUs, or no index lands
    inside the expected group.

    Returns:
        A comma-separated string of RDMA device names.
    """
    fallback = "mlx5_roce0,mlx5_roce4"

    # 1. Read the visible GPU indices from the environment.
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if not visible:
        warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.")
        return fallback

    # 2. Parse into ints, tolerating stray spaces and empty tokens.
    try:
        gpus = [int(tok.strip()) for tok in visible.split(",") if tok.strip()]
    except ValueError:
        warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {visible}")
        return fallback

    # Nothing parsed, or more GPUs than one RDMA group supports.
    if not gpus or len(gpus) > 4:
        return fallback

    # 3. Each group of 4 GPUs shares 4 consecutive RDMA devices.
    group_start = min(gpus) // 4 * 4

    # 4. Map every in-group GPU index to its RDMA device name.
    devices = []
    for gpu in gpus:
        if not (group_start <= gpu < group_start + 4):
            warnings.warn(
                f"GPU index {gpu} is outside expected group {group_start}-{group_start + 3}"
            )
            continue
        devices.append(f"mlx5_roce{group_start + gpu % 4}")

    return ",".join(devices) if devices else fallback
|
||||
|
||||
Reference in New Issue
Block a user