Fix chunked prefill with output logprob (#2083)
This commit is contained in:
@@ -302,7 +302,11 @@ class PrefillAdder:
|
|||||||
if (
|
if (
|
||||||
self.rem_chunk_tokens is None
|
self.rem_chunk_tokens is None
|
||||||
or input_tokens <= self.rem_chunk_tokens
|
or input_tokens <= self.rem_chunk_tokens
|
||||||
or (req.return_logprob and req.normalized_prompt_logprob is None)
|
or (
|
||||||
|
req.return_logprob
|
||||||
|
and req.normalized_prompt_logprob is None
|
||||||
|
and req.logprob_start_len != len(req.origin_input_ids) - 1
|
||||||
|
)
|
||||||
):
|
):
|
||||||
# Non-chunked prefill
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
|
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
print(json.dumps(response_json, indent=2))
|
print(json.dumps(response_json, indent=2))
|
||||||
|
|
||||||
for i, res in enumerate(response_json):
|
for i, res in enumerate(response_json):
|
||||||
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
|
self.assertEqual(
|
||||||
res["meta_info"]["input_token_logprobs"]
|
res["meta_info"]["prompt_tokens"],
|
||||||
|
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
|
||||||
)
|
)
|
||||||
assert prompts[i].endswith(
|
assert prompts[i].endswith(
|
||||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||||
)
|
)
|
||||||
|
|
||||||
assert res["meta_info"]["completion_tokens"] == new_tokens
|
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||||
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||||
res["text"] == "".join(
|
self.assertEqual(
|
||||||
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
|
res["text"],
|
||||||
|
"".join([x[-1] for x in res["meta_info"]["output_token_logprobs"]]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_logprob_with_chunked_prefill(self):
|
||||||
|
new_tokens = 4
|
||||||
|
prompts = "I have a very good idea on this. " * 8000
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompts,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": new_tokens,
|
||||||
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"logprob_start_len": -1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
print(json.dumps(response_json, indent=2))
|
||||||
|
|
||||||
|
res = response_json
|
||||||
|
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||||
|
|
||||||
def test_get_memory_pool_size(self):
|
def test_get_memory_pool_size(self):
|
||||||
response = requests.post(self.base_url + "/get_memory_pool_size")
|
response = requests.post(self.base_url + "/get_memory_pool_size")
|
||||||
assert isinstance(response.json(), int)
|
self.assertIsInstance(response.json(), int)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user