[PD] Support PD disaggregation with Prefill PP (#8846)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
Shangming Cai
2025-08-17 09:31:31 +08:00
committed by GitHub
parent 6a9d6ca33c
commit 384f8ab5ce
11 changed files with 632 additions and 82 deletions

View File

@@ -9,6 +9,8 @@ import time
import unittest
from types import SimpleNamespace
import requests
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
@@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase):
# Wait a little bit so that the memory check happens.
time.sleep(4)
def test_logprob(self):
    """Check that ``/generate`` reports logprob lists of the expected lengths.

    Posts a single greedy request (temperature 0) with ``return_logprob``
    enabled, ``top_logprobs_num`` of 5 and ``logprob_start_len`` of 0, then
    verifies how many input-token, output-token and output-top logprob
    entries come back under ``meta_info``.
    """
    # Single source of truth: the requested generation length is also the
    # expected number of output logprob entries.
    max_new_tokens = 16

    response = requests.post(
        f"{self.base_url}/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
            "return_logprob": True,
            "top_logprobs_num": 5,
            "logprob_start_len": 0,
        },
    )
    # Surface a server-side error with its body instead of an opaque
    # KeyError / JSON decode error further down.
    self.assertEqual(response.status_code, 200, response.text)

    meta_info = response.json()["meta_info"]

    # unittest assertions are used rather than bare ``assert`` so the checks
    # survive ``python -O`` and report both values on failure.
    # NOTE(review): 6 presumably is the prompt's token count (incl. any BOS)
    # for the model under test -- confirm if the model/tokenizer changes.
    self.assertEqual(len(meta_info["input_token_logprobs"]), 6)
    self.assertEqual(len(meta_info["output_token_logprobs"]), max_new_tokens)
    self.assertEqual(len(meta_info["output_top_logprobs"]), max_new_tokens)
class TestQwenPPAccuracy(unittest.TestCase):
@classmethod