Files
sglang/test/srt/test_disaggregation_pp.py
Shangming Cai c560410da7 Refactor and optimize mooncake CI (#11162)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
2025-10-05 14:08:52 -07:00

91 lines
2.4 KiB
Python

import time
import unittest
from types import SimpleNamespace
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
)
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
    """GSM8K accuracy test for PD disaggregation with a pipeline-parallel prefill server.

    The prefill server runs with ``--tp-size 2 --pp-size 2``. The decode server
    runs with ``--tp-size 2`` starting at GPU id 4. Requests are sent through the
    load balancer started by ``launch_lb``. ``launch_lb``, ``wait_server_ready``,
    ``prefill_url``, ``decode_url``, ``base_host`` and ``lb_port`` come from
    ``TestDisaggregationBase``, which is not visible here.
    """

    @classmethod
    def setUpClass(cls):
        """Launch prefill and decode servers, wait until both are healthy, then start the LB.

        unittest does not run ``tearDownClass`` when ``setUpClass`` raises. Any
        failure after the first server launch therefore calls ``tearDownClass``
        explicitly before re-raising. This assumes the base class's teardown
        stops the launched processes; without it, servers that already started
        would outlive the failed test class.
        """
        super().setUpClass()
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        try:
            # Non-blocking: start both servers so they initialize concurrently.
            cls.start_prefill()
            cls.start_decode()

            # Block until both servers answer their /health endpoint.
            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            cls.launch_lb()
        except BaseException:
            # NOTE(review): assumes the base tearDownClass tolerates a partially
            # initialized class (e.g. decode server or LB never launched) — confirm.
            try:
                cls.tearDownClass()
            except Exception as cleanup_err:
                # Keep the original setup failure as the error that gets reported.
                print(
                    "tearDownClass after failed setUpClass also failed: "
                    f"{cleanup_err!r}"
                )
            raise

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server (TP=2, PP=2) and store it on ``cls.process_prefill``.

        This method does not wait for readiness; ``setUpClass`` polls the
        server's /health endpoint afterwards.
        """
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp-size",
            "2",
            "--pp-size",
            "2",
            # Host-specific device names; assumes the CI machine exposes them.
            "--disaggregation-ib-device",
            "mlx5_roce0,mlx5_roce1",
            # NOTE(review): overlap scheduling is disabled on the prefill side
            # only; the reason is not visible here — confirm it is needed with PP.
            "--disable-overlap-schedule",
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server (TP=2, base GPU id 4) and store it on ``cls.process_decode``.

        This method does not wait for readiness; ``setUpClass`` polls the
        server's /health endpoint afterwards.
        """
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            # Use the full flag name, as the prefill server does, instead of the
            # abbreviation "--tp". An abbreviation relies on argparse prefix
            # matching and breaks if another "--tp..." option is ever added.
            "--tp-size",
            "2",
            # Presumably chosen to stay clear of the GPUs used by the TP=2 x PP=2
            # prefill server (no --base-gpu-id there) — confirm.
            "--base-gpu-id",
            "4",
            # Host-specific device names; assumes the CI machine exposes them.
            "--disaggregation-ib-device",
            "mlx5_roce4,mlx5_roce5",
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K (200 questions) through the LB and require accuracy > 0.24."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.24)
        # Wait a little bit so that the memory check happens.
        time.sleep(5)
if __name__ == "__main__":
    # When executed directly as a script, run this module's test cases.
    # module="__main__" is unittest.main's documented default, spelled out here.
    unittest.main(module="__main__")