diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 01298a753..aa9814df9 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -128,6 +128,10 @@ class Envs: SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial") SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp") + # Test: pd-disaggregation + SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake") + SGLANG_TEST_PD_DISAGG_DEVICES = EnvStr(None) + # Model Parallel SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True) diff --git a/python/sglang/test/test_disaggregation_utils.py b/python/sglang/test/test_disaggregation_utils.py index 5c4601601..e8084f802 100644 --- a/python/sglang/test/test_disaggregation_utils.py +++ b/python/sglang/test/test_disaggregation_utils.py @@ -1,13 +1,17 @@ +import os import time +import warnings from urllib.parse import urlparse import requests +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_with_error_check, ) @@ -27,6 +31,24 @@ class TestDisaggregationBase(CustomTestCase): print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None + # config transfer backend and rdma devices + if is_in_ci(): + cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"] + cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()] + else: + cls.transfer_backend = [ + "--disaggregation-transfer-backend", + envs.SGLANG_TEST_PD_DISAGG_BACKEND.get(), + ] + cls.rdma_devices = [ + "--disaggregation-ib-device", + envs.SGLANG_TEST_PD_DISAGG_DEVICES.get(), + ] + if cls.rdma_devices[1] is None: + cls.rdma_devices = [] + msg = "No RDMA devices specified for disaggregation test, using default settings." + warnings.warn(msg) + @classmethod def launch_lb(cls): lb_command = [ @@ -75,3 +97,44 @@ class TestDisaggregationBase(CustomTestCase): # wait for 5 seconds time.sleep(5) + + +def get_rdma_devices_args(): + # 1. Get visible GPU indices + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if not cuda_visible_devices: + warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.") + return "mlx5_roce0,mlx5_roce4" + + try: + # Convert to list of integers (handling possible spaces and empty strings) + gpu_indices = [ + int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip() + ] + if not gpu_indices or len(gpu_indices) > 4: + return "mlx5_roce0,mlx5_roce4" + except ValueError: + warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}") + return "mlx5_roce0,mlx5_roce4" + + # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices) + base_rdma_group = min(gpu_indices) // 4 * 4 + + # 3. Generate RDMA device names + rdma_devices = [] + for gpu_idx in gpu_indices: + # Validate GPU index within expected range + if gpu_idx < base_rdma_group or gpu_idx >= base_rdma_group + 4: + warnings.warn( + f"GPU index {gpu_idx} is outside expected group {base_rdma_group}-{base_rdma_group+3}" + ) + continue + + # Map GPU index to RDMA device index + rdma_index = base_rdma_group // 4 * 4 + (gpu_idx % 4) + rdma_devices.append(f"mlx5_roce{rdma_index}") + + if not rdma_devices: + return "mlx5_roce0,mlx5_roce4" + + return ",".join(rdma_devices) diff --git a/test/srt/hicache/test_disaggregation_hicache.py b/test/srt/hicache/test_disaggregation_hicache.py index 1b4015054..797393f7c 100644 --- a/test/srt/hicache/test_disaggregation_hicache.py +++ b/test/srt/hicache/test_disaggregation_hicache.py @@ -70,11 +70,8 @@ class DisaggregationHiCacheBase(TestDisaggregationBase): "wait_complete", "--mem-fraction-static", "0.8", - "--disaggregation-ib-device", - "mlx5_roce0", - "--disaggregation-transfer-backend", - "mooncake", ] + prefill_args += cls.transfer_backend + cls.rdma_devices env = { **os.environ, "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir, @@ -148,11 +145,8 @@ class TestDisaggregationPrefillWithHiCache(DisaggregationHiCacheBase): "0.8", "--base-gpu-id", "1", - "--disaggregation-ib-device", - "mlx5_roce0", - "--disaggregation-transfer-backend", - "mooncake", ] + decode_args += cls.transfer_backend + cls.rdma_devices env = { **os.environ, "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir, @@ -201,10 +195,6 @@ class TestDisaggregationDecodeWithHiCache(DisaggregationHiCacheBase): "0.8", "--base-gpu-id", "1", - "--disaggregation-ib-device", - "mlx5_roce0", - "--disaggregation-transfer-backend", - "mooncake", "--disaggregation-decode-enable-offload-kvcache", "--hicache-ratio", "1.2", @@ -215,6 +205,7 @@ class TestDisaggregationDecodeWithHiCache(DisaggregationHiCacheBase): "--hicache-storage-prefetch-policy", "wait_complete", ] + decode_args += cls.transfer_backend + cls.rdma_devices env = { **os.environ, "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir, diff --git a/test/srt/hicache/test_hicache_storage_mooncake_backend.py b/test/srt/hicache/test_hicache_storage_mooncake_backend.py index 631a47652..657fc9680 100644 --- a/test/srt/hicache/test_hicache_storage_mooncake_backend.py +++ b/test/srt/hicache/test_hicache_storage_mooncake_backend.py @@ -15,6 +15,7 @@ import requests from test_hicache_storage_file_backend import HiCacheStorageBaseMixin from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, CustomTestCase, @@ -192,7 +193,7 @@ class HiCacheStorageMooncakeBackendBaseMixin(HiCacheStorageBaseMixin): """Get additional server arguments specific to configuration - override in subclasses""" server_args = { - "--tp-size": 1, + "--tp-size": 2, "--hicache-ratio": 2, "--hicache-storage-backend": "mooncake", } @@ -202,7 +203,7 @@ class HiCacheStorageMooncakeBackendBaseMixin(HiCacheStorageBaseMixin): "MOONCAKE_MASTER": f"127.0.0.1:{cls.mooncake_master_port}", "MOONCAKE_PROTOCOL": "rdma", "MC_MS_AUTO_DISC": "0", - "MOONCAKE_DEVICE": "mlx5_roce0,mlx5_roce1", + "MOONCAKE_DEVICE": get_rdma_devices_args(), "MOONCAKE_TE_META_DATA_SERVER": f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata", "MOONCAKE_GLOBAL_SEGMENT_SIZE": "4294967296", # 4 GiB } diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 186f7c260..7eb82e36e 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -134,11 +134,13 @@ suites = { TestFile("lora/test_lora_tp.py", 116), TestFile("rl/test_update_weights_from_distributed.py", 103), TestFile("test_data_parallelism.py", 73), + TestFile("test_disaggregation.py", 499), TestFile("test_dp_attention.py", 594), TestFile("test_load_weights_from_remote_instance.py", 72), TestFile("test_patch_torch.py", 19), TestFile("test_release_memory_occupation.py", 257), TestFile("hicache/test_hicache_storage_file_backend.py", 200), + TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), TestFile("hicache/test_hicache_storage_3fs_backend.py", 200), ], "per-commit-4-gpu": [ @@ -149,9 +151,7 @@ suites = { TestFile("test_multi_instance_release_memory_occupation.py", 64), ], "per-commit-8-gpu": [ - TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), TestFile("lora/test_lora_llama4.py", 400), - TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation_dp_attention.py", 155), TestFile("test_disaggregation_different_tp.py", 600), TestFile("test_disaggregation_pp.py", 140), diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 9fecf5c59..06cf0e9e6 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -40,10 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", + "1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -58,12 +57,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "2", + "1", "--base-gpu-id", - "2", - "--disaggregation-ib-device", - "mlx5_roce2,mlx5_roce3", + "1", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -171,10 +169,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", + "1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -189,12 +186,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "2", + "1", "--base-gpu-id", - "2", - "--disaggregation-ib-device", - "mlx5_roce2,mlx5_roce3", + "1", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -270,10 +266,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", + "1", ] + cls.spec_args + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -288,12 +283,11 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "2", + "1", "--base-gpu-id", - "2", - "--disaggregation-ib-device", - "mlx5_roce2,mlx5_roce3", + "1", ] + cls.spec_args + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -346,10 +340,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", + "1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -364,12 +357,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "2", + "1", "--base-gpu-id", - "2", - "--disaggregation-ib-device", - "mlx5_roce2,mlx5_roce3", + "1", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index 3fd00c217..9664d7cec 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -41,9 +41,8 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): "prefill", "--tp", "4", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -61,9 +60,8 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): "2", "--base-gpu-id", "4", - "--disaggregation-ib-device", - "mlx5_roce4,mlx5_roce5", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -115,9 +113,8 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): "prefill", "--tp", "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -135,9 +132,8 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): "4", "--base-gpu-id", "4", - "--disaggregation-ib-device", - "mlx5_roce4,mlx5_roce5", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -189,9 +185,8 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase): "prefill", "--tp", "4", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -209,9 +204,8 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase): "2", "--base-gpu-id", "4", - "--disaggregation-ib-device", - "mlx5_roce4,mlx5_roce5", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, @@ -263,9 +257,8 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase): "prefill", "--tp", "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -283,9 +276,8 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase): "4", "--base-gpu-id", "4", - "--disaggregation-ib-device", - "mlx5_roce4,mlx5_roce5", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, diff --git a/test/srt/test_disaggregation_dp_attention.py b/test/srt/test_disaggregation_dp_attention.py index bf934a913..45a39b0d3 100644 --- a/test/srt/test_disaggregation_dp_attention.py +++ b/test/srt/test_disaggregation_dp_attention.py @@ -45,9 +45,8 @@ class TestDisaggregationDPAttention(TestDisaggregationBase): "--dp", "2", "--enable-dp-attention", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -68,9 +67,8 @@ class TestDisaggregationDPAttention(TestDisaggregationBase): "--enable-dp-attention", "--base-gpu-id", "2", - "--disaggregation-ib-device", - "mlx5_roce2,mlx5_roce3", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url, diff --git a/test/srt/test_disaggregation_pp.py b/test/srt/test_disaggregation_pp.py index 7367e95a0..b20ba8898 100644 --- a/test/srt/test_disaggregation_pp.py +++ b/test/srt/test_disaggregation_pp.py @@ -37,10 +37,9 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): "2", "--pp-size", "2", - "--disaggregation-ib-device", - "mlx5_roce0,mlx5_roce1", "--disable-overlap-schedule", ] + prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -58,9 +57,8 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): "2", "--base-gpu-id", "4", - "--disaggregation-ib-device", - "mlx5_roce4,mlx5_roce5", ] + decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( cls.model, cls.decode_url,