[Fix] Fix logprob and normalized_logprob (#1428)
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 32,
|
||||
"max_new_tokens": 16,
|
||||
"n": n,
|
||||
},
|
||||
"stream": stream,
|
||||
@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
for line in response.iter_lines():
|
||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||
response_json.append(json.loads(line[6:]))
|
||||
print(json.dumps(response_json))
|
||||
|
||||
print(json.dumps(response_json, indent=2))
|
||||
print("=" * 100)
|
||||
|
||||
def test_simple_decode(self):
|
||||
@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
self.run_decode(n=3, stream=True)
|
||||
|
||||
def test_logprob(self):
|
||||
for top_logprobs_num in [0, 3]:
|
||||
for return_text in [True, False]:
|
||||
self.run_decode(
|
||||
return_logprob=True,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text=return_text,
|
||||
)
|
||||
self.run_decode(
|
||||
return_logprob=True,
|
||||
top_logprobs_num=5,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
def test_logprob_start_len(self):
|
||||
logprob_start_len = 4
|
||||
new_tokens = 4
|
||||
prompts = [
|
||||
"I have a very good idea on",
|
||||
"Today is a sunndy day and",
|
||||
]
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": prompts,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": new_tokens,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": 5,
|
||||
"return_text_in_logprobs": True,
|
||||
"logprob_start_len": logprob_start_len,
|
||||
},
|
||||
)
|
||||
response_json = response.json()
|
||||
print(json.dumps(response_json, indent=2))
|
||||
|
||||
for i, res in enumerate(response_json):
|
||||
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
|
||||
res["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
assert prompts[i].endswith(
|
||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||
)
|
||||
|
||||
assert res["meta_info"]["completion_tokens"] == new_tokens
|
||||
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
|
||||
res["text"] == "".join(
|
||||
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user