fix: increase max_new_tokens when testing generation models (#1244)

This commit is contained in:
Yineng Zhang
2024-08-28 22:12:36 +10:00
committed by GitHub
parent 6c49831394
commit 66975360e7
2 changed files with 6 additions and 6 deletions

View File

@@ -30,7 +30,7 @@ DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt # the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is", # "The capital of France is",
"Apple is red. Banana is Yellow. " * 800 + "Apple is", "Apple is red. Banana is Yellow. " * 800 + "Apple is",
"The capital of the United Kindom is", "The capital of the United Kingdom is",
"Today is a sunny day and I like", "Today is a sunny day and I like",
"AI is a field of computer science focused on", "AI is a field of computer science focused on",
] ]

View File

@@ -62,7 +62,6 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
class TestGenerationModels(unittest.TestCase): class TestGenerationModels(unittest.TestCase):
def assert_close_prefill_logits_and_output_strs( def assert_close_prefill_logits_and_output_strs(
self, self,
prompts, prompts,
@@ -99,14 +98,15 @@ class TestGenerationModels(unittest.TestCase):
abs(hf_logprobs - srt_logprobs) < prefill_tolerance abs(hf_logprobs - srt_logprobs) < prefill_tolerance
), "prefill logprobs are not all close" ), "prefill logprobs are not all close"
print(hf_outputs.output_strs) print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
print(srt_outputs.output_strs) print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
rouge_l_scores = calculate_rouge_l( rouge_l_scores = calculate_rouge_l(
hf_outputs.output_strs, srt_outputs.output_strs hf_outputs.output_strs, srt_outputs.output_strs
) )
print(f"rouge_l_scores={rouge_l_scores}")
assert all( assert all(
score >= rouge_threshold for score in rouge_l_scores score >= rouge_threshold for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than {rouge_threshold}" ), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}"
def test_prefill_logits_and_output_strs(self): def test_prefill_logits_and_output_strs(self):
for ( for (
@@ -117,7 +117,7 @@ class TestGenerationModels(unittest.TestCase):
rouge_threshold, rouge_threshold,
) in MODELS: ) in MODELS:
for torch_dtype in TORCH_DTYPES: for torch_dtype in TORCH_DTYPES:
max_new_tokens = 8 max_new_tokens = 32
self.assert_close_prefill_logits_and_output_strs( self.assert_close_prefill_logits_and_output_strs(
DEFAULT_PROMPTS, DEFAULT_PROMPTS,
model, model,