Add more logprob tests (#3162)
This commit is contained in:
@@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=("--enable-custom-logit-processor",),
|
other_args=(
|
||||||
|
"--enable-custom-logit-processor",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.8",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
"return_logprob": True,
|
"return_logprob": True,
|
||||||
"logprob_start_len": -1,
|
"logprob_start_len": -1,
|
||||||
|
"top_logprobs_num": 5,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
print(json.dumps(response_json, indent=2))
|
# print(json.dumps(response_json, indent=2))
|
||||||
|
|
||||||
res = response_json
|
res = response_json
|
||||||
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||||
|
|
||||||
|
# Test the number of tokens are correct
|
||||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)
|
||||||
|
|
||||||
|
# Test the top-1 tokens are the same as output tokens (because temp = 0.0)
|
||||||
|
for i in range(new_tokens):
|
||||||
|
self.assertListEqual(
|
||||||
|
res["meta_info"]["output_token_logprobs"][i],
|
||||||
|
res["meta_info"]["output_top_logprobs"][i][0],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)
|
||||||
|
|
||||||
def test_logprob_match(self):
|
def test_logprob_match(self):
|
||||||
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
|
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
|
||||||
@@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
max_diff = np.max(diff)
|
max_diff = np.max(diff)
|
||||||
self.assertLess(max_diff, 0.25)
|
self.assertLess(max_diff, 0.25)
|
||||||
|
|
||||||
|
def run_logprob_check(self, arg):
    """Send one ``/generate`` request and check the sizes of the logprob lists.

    ``arg`` is a 6-tuple, in this order:
    ``(input_len, output_len, temperature, logprob_start_len,
    return_logprob, top_logprobs_num)``.

    The prompt is the token ids ``0 .. input_len - 1``. The checks are
    layered: token counts are always verified, per-token logprob list
    lengths only when ``return_logprob`` is truthy, and top-k list lengths
    only when ``top_logprobs_num`` is additionally truthy.
    """
    in_len, out_len, temp, start_len, want_logprob, top_k = arg

    payload = {
        "input_ids": list(range(in_len)),
        "sampling_params": {
            "temperature": temp,
            "max_new_tokens": out_len,
        },
        "return_logprob": want_logprob,
        "logprob_start_len": start_len,
        "top_logprobs_num": top_k,
    }
    reply = requests.post(self.base_url + "/generate", json=payload).json()

    meta = reply["meta_info"]
    self.assertEqual(meta["prompt_tokens"], in_len)
    self.assertEqual(meta["completion_tokens"], out_len)

    # Nothing else to verify when logprobs were not requested.
    if not want_logprob:
        return

    # Number of prompt positions expected to be absent from the input
    # logprob lists. When logprob_start_len == 0 a padding entry is added
    # for the first token, so nothing is missing; otherwise there is no
    # padding and one extra position is missing.
    missing = start_len + int(start_len != 0)

    self.assertEqual(
        len(meta["input_token_logprobs"]) + missing,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_token_logprobs"]), out_len)

    # Top-k lists are only checked when a non-zero k was requested.
    if not top_k:
        return

    self.assertEqual(
        len(meta["input_top_logprobs"]) + missing,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_top_logprobs"]), out_len)

    greedy = temp == 0
    # Both lists were just asserted to have out_len entries, so pairing
    # them walks exactly the positions 0 .. out_len - 1.
    for sampled, top_entries in zip(
        meta["output_token_logprobs"], meta["output_top_logprobs"]
    ):
        self.assertEqual(len(top_entries), top_k)

        # With temperature == 0 the sampled token must be the top-1 entry.
        if greedy:
            self.assertListEqual(sampled, top_entries[0])
|
||||||
|
|
||||||
|
def test_logprob_mixed(self):
    """Run ``run_logprob_check`` over a shuffled grid of request shapes, 8 at a time.

    The grid is the product of prompt lengths, output lengths, logprob
    start offsets, the ``return_logprob`` flag and the top-k count, with
    temperature fixed at 0. Combinations whose start offset is not
    strictly inside the prompt are left out.
    """
    temperature = 0

    # Tuple layout expected by run_logprob_check:
    # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
    cases = [
        (
            input_len,
            output_len,
            temperature,
            logprob_start_len,
            return_logprob,
            top_logprobs_num,
        )
        for input_len in [1000, 2000]
        for output_len in [4, 8]
        for logprob_start_len in [0, 500, 1000]
        for return_logprob in [True, False]
        for top_logprobs_num in [0, 5]
        if logprob_start_len < input_len
    ]

    # Randomize submission order so differently shaped requests are mixed.
    random.shuffle(cases)

    with ThreadPoolExecutor(max_workers=8) as pool:
        # Draining the iterator re-raises the first failure from a worker.
        for _ in pool.map(self.run_logprob_check, cases):
            pass
|
||||||
|
|
||||||
def test_logprob_grammar(self):
|
def test_logprob_grammar(self):
|
||||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||||
allowed_tokens = [" Yes", " No"]
|
allowed_tokens = [" Yes", " No"]
|
||||||
|
|||||||
Reference in New Issue
Block a user