Fix chunked prefill with output logprob (#2083)
This commit is contained in:
@@ -302,7 +302,11 @@ class PrefillAdder:
|
||||
if (
|
||||
self.rem_chunk_tokens is None
|
||||
or input_tokens <= self.rem_chunk_tokens
|
||||
or (req.return_logprob and req.normalized_prompt_logprob is None)
|
||||
or (
|
||||
req.return_logprob
|
||||
and req.normalized_prompt_logprob is None
|
||||
and req.logprob_start_len != len(req.origin_input_ids) - 1
|
||||
)
|
||||
):
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
print(json.dumps(response_json, indent=2))
|
||||
|
||||
for i, res in enumerate(response_json):
|
||||
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
|
||||
res["meta_info"]["input_token_logprobs"]
|
||||
self.assertEqual(
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
|
||||
)
|
||||
assert prompts[i].endswith(
|
||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||
)
|
||||
|
||||
assert res["meta_info"]["completion_tokens"] == new_tokens
|
||||
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
|
||||
res["text"] == "".join(
|
||||
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
|
||||
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||
self.assertEqual(
|
||||
res["text"],
|
||||
"".join([x[-1] for x in res["meta_info"]["output_token_logprobs"]]),
|
||||
)
|
||||
|
||||
def test_logprob_with_chunked_prefill(self):
|
||||
new_tokens = 4
|
||||
prompts = "I have a very good idea on this. " * 8000
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": prompts,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": new_tokens,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": -1,
|
||||
},
|
||||
)
|
||||
response_json = response.json()
|
||||
print(json.dumps(response_json, indent=2))
|
||||
|
||||
res = response_json
|
||||
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||
|
||||
def test_get_memory_pool_size(self):
|
||||
response = requests.post(self.base_url + "/get_memory_pool_size")
|
||||
assert isinstance(response.json(), int)
|
||||
self.assertIsInstance(response.json(), int)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user