[ci]use H20 to run disaggregation test (#11543)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
@@ -15,6 +16,8 @@ from sglang.test.test_utils import (
|
||||
popen_with_error_check,
|
||||
)
|
||||
|
||||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestDisaggregationBase(CustomTestCase):
|
||||
@classmethod
|
||||
@@ -100,11 +103,28 @@ class TestDisaggregationBase(CustomTestCase):
|
||||
|
||||
|
||||
def get_rdma_devices_args():
    """Return a comma-separated list of RDMA device names for the visible GPUs.

    The candidate device list is read from the ``SGLANG_CI_RDMA_ALL_DEVICES``
    environment variable (comma-separated names); when it is unset or empty,
    ``mlx5_roce0`` .. ``mlx5_roce7`` is used. The GPUs are read from
    ``CUDA_VISIBLE_DEVICES``.

    A default pair (the first device and the one at the midpoint of the
    list) is returned when ``CUDA_VISIBLE_DEVICES`` is unset, cannot be
    parsed as integers, is empty after parsing, or names more than 4 GPUs.
    Otherwise one device name is emitted per visible GPU, in the order the
    GPUs are listed.

    Returns:
        str: Device names joined with ``","``.
    """
    # NOTE(review): 8 is presumably the number of GPUs per CI node; the
    # original mapping hard-coded it as ``gpu_idx // (8 // n_rdma)`` -- confirm.
    gpus_per_node = 8

    def _parse_list_env(var_name: str):
        """Return the stripped, non-empty items of a comma-separated env var, or None."""
        val = os.getenv(var_name)
        if not val:
            return None
        items = [x.strip() for x in val.split(",") if x.strip()]
        return items or None

    def _pick_default_pair(rdma_all_devices):
        """Return the first device and the device at the midpoint of the list."""
        return [rdma_all_devices[0], rdma_all_devices[len(rdma_all_devices) // 2]]

    rdma_all_devices = _parse_list_env("SGLANG_CI_RDMA_ALL_DEVICES") or [
        f"mlx5_roce{i}" for i in range(8)
    ]
    logger.info("Resolved rdma_all_devices=%s", rdma_all_devices)

    n_rdma = len(rdma_all_devices)
    default_devices = ",".join(_pick_default_pair(rdma_all_devices))

    # 1. Get visible GPU indices
    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if not cuda_visible_devices:
        warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.")
        return default_devices

    try:
        # Convert to list of integers (handling possible spaces and empty strings)
        gpu_indices = [
            int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip()
        ]
    except ValueError:
        warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}")
        return default_devices

    if not gpu_indices or len(gpu_indices) > 4:
        return default_devices

    # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices)
    base_rdma_group = (min(gpu_indices) // 4) * 4
    for gpu_idx in gpu_indices:
        if not (base_rdma_group <= gpu_idx < base_rdma_group + 4):
            # Only a diagnostic: the GPU is still mapped to a device below.
            warnings.warn(
                f"GPU index {gpu_idx} is outside expected group "
                f"{base_rdma_group}-{base_rdma_group+3}"
            )

    # 3. Generate RDMA device names
    # Spread the GPU slots proportionally over the device list. This equals
    # ``gpu_idx // (8 // n_rdma)`` whenever n_rdma divides 8, but does not
    # divide by zero when more than 8 devices are configured and stays inside
    # the list for device counts that do not divide 8. Indices outside
    # 0..gpus_per_node-1 are clamped to the nearest end of the list.
    last_nic = n_rdma - 1
    rdma_devices = []
    for gpu_idx in gpu_indices:
        nic_index = gpu_idx * n_rdma // gpus_per_node
        rdma_devices.append(rdma_all_devices[max(0, min(nic_index, last_nic))])

    return ",".join(rdma_devices)
|
||||
|
||||
Reference in New Issue
Block a user