[PD] Support PD disaggregation with Prefill PP (#8846)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
Shangming Cai
2025-08-17 09:31:31 +08:00
committed by GitHub
parent 6a9d6ca33c
commit 384f8ab5ce
11 changed files with 632 additions and 82 deletions

View File

@@ -0,0 +1,133 @@
import json
import os
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPDPPAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
cls.start_decode()
# Block until both
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang.srt.disaggregation.mini_lb",
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp-size",
"2",
"--pp-size",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"--disable-overlap-schedule",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.24)
# Wait a little bit so that the memory check happens.
time.sleep(5)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,6 +9,8 @@ import time
import unittest
from types import SimpleNamespace
import requests
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
@@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase):
# Wait a little bit so that the memory check happens.
time.sleep(4)
def test_logprob(self):
response = requests.post(
f"{self.base_url}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
"return_logprob": True,
"top_logprobs_num": 5,
"logprob_start_len": 0,
},
)
response_json = response.json()
input_token_logprobs = response_json["meta_info"]["input_token_logprobs"]
output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
output_top_logprobs = response_json["meta_info"]["output_top_logprobs"]
assert len(input_token_logprobs) == 6
assert len(output_token_logprobs) == 16
assert len(output_top_logprobs) == 16
class TestQwenPPAccuracy(unittest.TestCase):
@classmethod