fix: increase max_new_tokens when testing generation models (#1244)
This commit is contained in:
@@ -30,7 +30,7 @@ DEFAULT_PROMPTS = [
|
|||||||
# the output of gemma-2-2b from SRT is unstable on the commented prompt
|
# the output of gemma-2-2b from SRT is unstable on the commented prompt
|
||||||
# "The capital of France is",
|
# "The capital of France is",
|
||||||
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||||
"The capital of the United Kindom is",
|
"The capital of the United Kingdom is",
|
||||||
"Today is a sunny day and I like",
|
"Today is a sunny day and I like",
|
||||||
"AI is a field of computer science focused on",
|
"AI is a field of computer science focused on",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
|
|||||||
|
|
||||||
|
|
||||||
class TestGenerationModels(unittest.TestCase):
|
class TestGenerationModels(unittest.TestCase):
|
||||||
|
|
||||||
def assert_close_prefill_logits_and_output_strs(
|
def assert_close_prefill_logits_and_output_strs(
|
||||||
self,
|
self,
|
||||||
prompts,
|
prompts,
|
||||||
@@ -99,14 +98,15 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
|
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
|
||||||
), "prefill logprobs are not all close"
|
), "prefill logprobs are not all close"
|
||||||
|
|
||||||
print(hf_outputs.output_strs)
|
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
|
||||||
print(srt_outputs.output_strs)
|
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
|
||||||
rouge_l_scores = calculate_rouge_l(
|
rouge_l_scores = calculate_rouge_l(
|
||||||
hf_outputs.output_strs, srt_outputs.output_strs
|
hf_outputs.output_strs, srt_outputs.output_strs
|
||||||
)
|
)
|
||||||
|
print(f"rouge_l_scores={rouge_l_scores}")
|
||||||
assert all(
|
assert all(
|
||||||
score >= rouge_threshold for score in rouge_l_scores
|
score >= rouge_threshold for score in rouge_l_scores
|
||||||
), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
|
), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}"
|
||||||
|
|
||||||
def test_prefill_logits_and_output_strs(self):
|
def test_prefill_logits_and_output_strs(self):
|
||||||
for (
|
for (
|
||||||
@@ -117,7 +117,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
rouge_threshold,
|
rouge_threshold,
|
||||||
) in MODELS:
|
) in MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
max_new_tokens = 8
|
max_new_tokens = 32
|
||||||
self.assert_close_prefill_logits_and_output_strs(
|
self.assert_close_prefill_logits_and_output_strs(
|
||||||
DEFAULT_PROMPTS,
|
DEFAULT_PROMPTS,
|
||||||
model,
|
model,
|
||||||
|
|||||||
Reference in New Issue
Block a user