From af02f99b7ccbddd74bae98961428c32ae07d6079 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 26 Jan 2025 22:24:55 -0800 Subject: [PATCH] Add more logprob tests (#3162) --- test/srt/test_srt_endpoint.py | 117 +++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index b4e71183d..68db1d699 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=("--enable-custom-logit-processor",), + other_args=( + "--enable-custom-logit-processor", + "--mem-fraction-static", + "0.8", + ), ) @classmethod @@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase): }, "return_logprob": True, "logprob_start_len": -1, + "top_logprobs_num": 5, }, ) response_json = response.json() - print(json.dumps(response_json, indent=2)) + # print(json.dumps(response_json, indent=2)) res = response_json self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) + + # Test the number of tokens are correct self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens) + + # Test the top-1 tokens are the same as output tokens (because temp = 0.0) + for i in range(new_tokens): + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5) def test_logprob_match(self): """Test the output logprobs are close to the input logprobs if we run a prefill again.""" @@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase): max_diff = np.max(diff) self.assertLess(max_diff, 0.25) + def run_logprob_check(self, arg): + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) = arg + input_ids = list(range(input_len)) + + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + }, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + }, + ) + response_json = response.json() + + res = response_json + self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) + self.assertEqual(res["meta_info"]["completion_tokens"], output_len) + + # Test the number of tokens are correct + if return_logprob: + # This is because if logprob_start_len == 0, we added a padding for the first token. + # In other cases, we do not add the padding + delta = 0 if logprob_start_len == 0 else 1 + + self.assertEqual( + len(res["meta_info"]["input_token_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) + + if top_logprobs_num: + self.assertEqual( + len(res["meta_info"]["input_top_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"]), output_len + ) + + for i in range(output_len): + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"][i]), + top_logprobs_num, + ) + + # Test the top-1 tokens are the same as output tokens if temperature == 0 + if temperature == 0: + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + + def test_logprob_mixed(self): + args = [] + temperature = 0 + # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num + for input_len in [1000, 2000]: + for output_len in [4, 8]: + for logprob_start_len in [0, 500, 1000]: + for return_logprob in [True, False]: + for top_logprobs_num in [0, 5]: + + if logprob_start_len >= input_len: + continue + + args.append( + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) + ) + + random.shuffle(args) + + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_logprob_check, args)) + def test_logprob_grammar(self): prompts = "Question: Is Paris the Capital of France? Answer:" allowed_tokens = [" Yes", " No"]