From 80e2c4a8de3ad34af12f6127956975b69c1beaa7 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 18 Nov 2024 13:16:28 -0800
Subject: [PATCH] Fix chunked prefill with output logprob (#2083)

---
 python/sglang/srt/managers/schedule_policy.py |  6 ++-
 test/srt/test_srt_endpoint.py                 | 41 +++++++++++++++----
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 2bfdffc42..9b3c35c00 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -302,7 +302,11 @@ class PrefillAdder:
         if (
             self.rem_chunk_tokens is None
             or input_tokens <= self.rem_chunk_tokens
-            or (req.return_logprob and req.normalized_prompt_logprob is None)
+            or (
+                req.return_logprob
+                and req.normalized_prompt_logprob is None
+                and req.logprob_start_len != len(req.origin_input_ids) - 1
+            )
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index b13ed9ac8..4ca17adb6 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -1,6 +1,6 @@
 """
 python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
-python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
+python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
 """

 import json
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
         print(json.dumps(response_json, indent=2))

         for i, res in enumerate(response_json):
-            assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
-                res["meta_info"]["input_token_logprobs"]
+            self.assertEqual(
+                res["meta_info"]["prompt_tokens"],
+                logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
             )
             assert prompts[i].endswith(
                 "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
             )

-            assert res["meta_info"]["completion_tokens"] == new_tokens
-            assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
-            res["text"] == "".join(
-                [x[-1] for x in res["meta_info"]["output_token_logprobs"]]
+            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
+            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+            self.assertEqual(
+                res["text"],
+                "".join([x[-1] for x in res["meta_info"]["output_token_logprobs"]]),
             )

+    def test_logprob_with_chunked_prefill(self):
+        new_tokens = 4
+        prompts = "I have a very good idea on this. " * 8000
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": prompts,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": new_tokens,
+                },
+                "return_logprob": True,
+                "logprob_start_len": -1,
+            },
+        )
+        response_json = response.json()
+        print(json.dumps(response_json, indent=2))
+
+        res = response_json
+        self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
+        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+
     def test_get_memory_pool_size(self):
         response = requests.post(self.base_url + "/get_memory_pool_size")
-        assert isinstance(response.json(), int)
+        self.assertIsInstance(response.json(), int)


 if __name__ == "__main__":