Files
sglang/test/srt/test_disaggregation_pp.py
2025-10-16 11:42:42 -07:00

90 lines
2.4 KiB
Python

import time
import unittest
from types import SimpleNamespace
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
try_cached_model,
)
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
    """GSM8K accuracy check for a prefill/decode disaggregated deployment.

    The prefill server is launched with ``--tp-size 2 --pp-size 2`` and
    ``--disable-overlap-schedule``; the decode server with ``--tp 2`` and
    ``--base-gpu-id 4``. The evaluation is pointed at ``base_host`` and
    ``lb_port``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch both servers, wait for their health endpoints, start the LB."""
        super().setUpClass()
        cls.model = try_cached_model(DEFAULT_MODEL_NAME_FOR_TEST)

        # Kick off both launches first; neither call is waited on here.
        cls.start_prefill()
        cls.start_decode()

        # Proceed only once each server answers on its /health endpoint.
        prefill_health = cls.prefill_url + "/health"
        cls.wait_server_ready(prefill_health)
        decode_health = cls.decode_url + "/health"
        cls.wait_server_ready(decode_health)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and keep its handle in ``process_prefill``."""
        # Transfer-backend and RDMA-device flags are supplied by the class
        # attributes and appended after the fixed flags.
        backend_flags = cls.transfer_backend + cls.rdma_devices
        server_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp-size",
            "2",
            "--pp-size",
            "2",
            "--disable-overlap-schedule",
            *backend_flags,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_flags,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and keep its handle in ``process_decode``."""
        backend_flags = cls.transfer_backend + cls.rdma_devices
        server_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "2",
            # NOTE(review): presumably offsets decode past the GPUs taken by
            # the prefill server (tp 2 x pp 2) -- confirm against the runner.
            "--base-gpu-id",
            "4",
            *backend_flags,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_flags,
        )

    def test_gsm8k(self):
        # Run the 5-shot GSM8K evaluation against base_host:lb_port and
        # require the reported accuracy to exceed 0.24.
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval(SimpleNamespace(**eval_config))
        # The `=` specifier prints the variable name, so `metrics` must keep
        # this exact name for the output to stay the same.
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.24)

        # Wait a little bit so that the memory check happens.
        time.sleep(5)
# Script entry point: run this module's tests when it is executed directly.
if __name__ == "__main__":
    unittest.main()