[CI] Refactor disaggregation tests (#10068)

Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Shangming Cai
2025-09-06 22:14:46 +08:00
committed by GitHub
parent 5f1eb20484
commit 00974e4f6e
5 changed files with 100 additions and 353 deletions

View File

@@ -0,0 +1,66 @@
import time
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
CustomTestCase,
popen_with_error_check,
)
class TestDisaggregationBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
pass
@classmethod
def launch_lb(cls):
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)

View File

@@ -139,6 +139,7 @@ suites = {
TestFile("lora/test_lora_llama4.py", 600),
TestFile("test_disaggregation.py", 499),
TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_disaggregation_pp.py", 60),
TestFile("test_full_deepseek_v3.py", 333),
],
"per-commit-8-gpu-b200": [

View File

@@ -7,21 +7,19 @@ from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
popen_with_error_check,
)
class TestDisaggregationAccuracy(CustomTestCase):
class TestDisaggregationAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
json.loads(output)
class TestDisaggregationMooncakeFailure(CustomTestCase):
class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
cls.launch_lb()
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def tearDownClass(cls):
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
super().tearDownClass()
@classmethod
def start_prefill(cls):
@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
raise e from health_check_error
class TestDisaggregationMooncakeSpec(CustomTestCase):
class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.20)
class TestDisaggregationSimulatedRetract(CustomTestCase):
class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "true"
@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
cls.launch_lb()
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
super().tearDownClass()
@classmethod
def start_prefill(cls):
@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,

View File

@@ -1,25 +1,20 @@
import os
import subprocess
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
popen_with_error_check,
)
class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# Temporarily disable JIT DeepGEMM
@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60)
class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# Temporarily disable JIT DeepGEMM
@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,

View File

@@ -1,29 +1,19 @@
import json
import os
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
popen_launch_pd_server,
)
class TestPDPPAccuracy(unittest.TestCase):
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode",
"prefill",
"--tp-size",
"2",
"1",
"--pp-size",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
"--disable-overlap-schedule",
]
cls.process_prefill = popen_launch_pd_server(
@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp",
"1",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval(args)
print(f"{metrics=}")