Add more logprob tests (#3162)
This commit is contained in:
@@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=("--enable-custom-logit-processor",),
|
||||
other_args=(
|
||||
"--enable-custom-logit-processor",
|
||||
"--mem-fraction-static",
|
||||
"0.8",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": -1,
|
||||
"top_logprobs_num": 5,
|
||||
},
|
||||
)
|
||||
response_json = response.json()
|
||||
print(json.dumps(response_json, indent=2))
|
||||
# print(json.dumps(response_json, indent=2))
|
||||
|
||||
res = response_json
|
||||
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||
|
||||
# Test the number of tokens are correct
|
||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)
|
||||
|
||||
# Test the top-1 tokens are the same as output tokens (because temp = 0.0)
|
||||
for i in range(new_tokens):
|
||||
self.assertListEqual(
|
||||
res["meta_info"]["output_token_logprobs"][i],
|
||||
res["meta_info"]["output_top_logprobs"][i][0],
|
||||
)
|
||||
self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)
|
||||
|
||||
def test_logprob_match(self):
|
||||
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
|
||||
@@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
max_diff = np.max(diff)
|
||||
self.assertLess(max_diff, 0.25)
|
||||
|
||||
def run_logprob_check(self, arg):
    """Send one ``/generate`` request and check the sizes of the returned logprobs.

    Args:
        arg: A 6-tuple ``(input_len, output_len, temperature,
            logprob_start_len, return_logprob, top_logprobs_num)``. The
            parameters are packed into a single tuple so this method can be
            handed directly to ``executor.map`` (see ``test_logprob_mixed``).

    Raises:
        AssertionError: If the HTTP request does not succeed or any of the
            reported token / logprob counts does not match the request.
    """
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg
    # Synthetic prompt: the token ids 0 .. input_len - 1.
    input_ids = list(range(input_len))

    response = requests.post(
        self.base_url + "/generate",
        json={
            "input_ids": input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
            },
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        },
    )
    # Surface the server's own error text instead of failing later with an
    # opaque KeyError / JSON decode error when the request was rejected.
    self.assertEqual(response.status_code, 200, msg=response.text)

    meta_info = response.json()["meta_info"]
    self.assertEqual(meta_info["prompt_tokens"], input_len)
    self.assertEqual(meta_info["completion_tokens"], output_len)

    # Nothing else to verify when logprobs were not requested.
    if not return_logprob:
        return

    # Test the number of tokens are correct.
    # If logprob_start_len == 0, a padding entry is added for the first token.
    # In other cases, the padding is not added.
    delta = 0 if logprob_start_len == 0 else 1

    self.assertEqual(
        len(meta_info["input_token_logprobs"]) + logprob_start_len + delta,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_token_logprobs"]), output_len)

    # Top-k lists are only checked when top logprobs were requested.
    if not top_logprobs_num:
        return

    self.assertEqual(
        len(meta_info["input_top_logprobs"]) + logprob_start_len + delta,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_top_logprobs"]), output_len)

    for i in range(output_len):
        self.assertEqual(
            len(meta_info["output_top_logprobs"][i]),
            top_logprobs_num,
        )

        # Test the top-1 tokens are the same as output tokens if temperature == 0
        if temperature == 0:
            self.assertListEqual(
                meta_info["output_token_logprobs"][i],
                meta_info["output_top_logprobs"][i][0],
            )
|
||||
|
||||
def test_logprob_mixed(self):
    """Run a shuffled grid of logprob requests through ``run_logprob_check`` concurrently.

    The grid covers every combination of prompt length, output length,
    ``logprob_start_len``, ``return_logprob`` and ``top_logprobs_num`` listed
    below, skipping combinations whose start offset is not strictly inside
    the prompt. Temperature is fixed at 0 for all cases.
    """
    greedy_temperature = 0

    # Tuple layout expected by run_logprob_check:
    # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
    cases = [
        (prompt_len, gen_len, greedy_temperature, start_len, want_logprob, top_k)
        for prompt_len in [1000, 2000]
        for gen_len in [4, 8]
        for start_len in [0, 500, 1000]
        # The start offset must fall inside the prompt.
        if start_len < prompt_len
        for want_logprob in [True, False]
        for top_k in [0, 5]
    ]

    # Randomize submission order so differently shaped requests are interleaved.
    random.shuffle(cases)

    with ThreadPoolExecutor(8) as pool:
        # Draining the iterator re-raises any assertion failure from a worker.
        list(pool.map(self.run_logprob_check, cases))
|
||||
|
||||
def test_logprob_grammar(self):
|
||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||
allowed_tokens = [" Yes", " No"]
|
||||
|
||||
Reference in New Issue
Block a user