[CI] Refactor disaggregation tests (#10068)

Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Shangming Cai
2025-09-06 22:14:46 +08:00
committed by GitHub
parent 5f1eb20484
commit 00974e4f6e
5 changed files with 100 additions and 353 deletions

View File

@@ -1,29 +1,19 @@
import json
import os
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
popen_launch_pd_server,
)
class TestPDPPAccuracy(unittest.TestCase):
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode",
"prefill",
"--tp-size",
"2",
"1",
"--pp-size",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
"--disable-overlap-schedule",
]
cls.process_prefill = popen_launch_pd_server(
@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp",
"1",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval(args)
print(f"{metrics=}")