From 66975360e7575a5f573cdaf5c6892d81afc3ed19 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 28 Aug 2024 22:12:36 +1000 Subject: [PATCH] fix: increase max_new_tokens when testing generation models (#1244) --- python/sglang/test/runners.py | 2 +- test/srt/models/test_generation_models.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 37ed2cf9a..e69d699a7 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -30,7 +30,7 @@ DEFAULT_PROMPTS = [ # the output of gemma-2-2b from SRT is unstable on the commented prompt # "The capital of France is", "Apple is red. Banana is Yellow. " * 800 + "Apple is", - "The capital of the United Kindom is", + "The capital of the United Kingdom is", "Today is a sunny day and I like", "AI is a field of computer science focused on", ] diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index b953ccf5d..e38584741 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -62,7 +62,6 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2): class TestGenerationModels(unittest.TestCase): - def assert_close_prefill_logits_and_output_strs( self, prompts, @@ -99,14 +98,15 @@ class TestGenerationModels(unittest.TestCase): abs(hf_logprobs - srt_logprobs) < prefill_tolerance ), "prefill logprobs are not all close" - print(hf_outputs.output_strs) - print(srt_outputs.output_strs) + print(f"hf_outputs.output_strs={hf_outputs.output_strs}") + print(f"srt_outputs.output_strs={srt_outputs.output_strs}") rouge_l_scores = calculate_rouge_l( hf_outputs.output_strs, srt_outputs.output_strs ) + print(f"rouge_l_scores={rouge_l_scores}") assert all( score >= rouge_threshold for score in rouge_l_scores - ), f"Not all ROUGE-L scores are greater than {rouge_threshold}" + ), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}" def test_prefill_logits_and_output_strs(self): for ( @@ -117,7 +117,7 @@ class TestGenerationModels(unittest.TestCase): rouge_threshold, ) in MODELS: for torch_dtype in TORCH_DTYPES: - max_new_tokens = 8 + max_new_tokens = 32 self.assert_close_prefill_logits_and_output_strs( DEFAULT_PROMPTS, model,