[CI] Return output logprobs in unit test (#1361)

This commit is contained in:
Ying Sheng
2024-09-09 13:05:13 -07:00
committed by GitHub
parent a7c47e0f02
commit 689ff588ec
2 changed files with 73 additions and 21 deletions

View File

@@ -21,9 +21,9 @@ import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
# Per-model test configuration. Judging by the `for (...) in MODELS` unpacking
# further down, a 6-field row ends with: tp_size, long_context_tolerance,
# prefill_tolerance, output_tolerance, rouge_threshold; the first field is the
# model path string.
# NOTE(review): this text looks like a rendered commit diff whose +/- markers
# were lost, so old and new rows sit side by side. Confirm against the real
# file before relying on this list as-is.
MODELS = [
# presumably the pre-change rows (5 fields, no output_tolerance) -- TODO confirm
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
("google/gemma-2-2b", 1, 3, 3e-2, 1),
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
# presumably the post-change rows (6 fields; 5th is output_tolerance) -- TODO confirm
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
]
TORCH_DTYPES = [torch.float16]
@@ -70,6 +70,7 @@ class TestGenerationModels(unittest.TestCase):
torch_dtype,
max_new_tokens,
prefill_tolerance,
output_tolerance,
rouge_threshold,
long_context_tolerance,
) -> None:
@@ -89,15 +90,37 @@ class TestGenerationModels(unittest.TestCase):
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)):
# input logprobs comparison
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
if hf_logprobs.shape[0] <= 100:
input_len = hf_logprobs.shape[0]
print(
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
# output logprobs comparison
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
# print(
# "output logprobs diff",
# [
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
# for j in range(max_new_tokens)
# ],
# )
print(
"output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < output_tolerance
), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
# output strings comparison
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
rouge_l_scores = calculate_rouge_l(
@@ -114,6 +137,7 @@ class TestGenerationModels(unittest.TestCase):
tp_size,
long_context_tolerance,
prefill_tolerance,
output_tolerance,
rouge_threshold,
) in MODELS:
for torch_dtype in TORCH_DTYPES:
@@ -125,6 +149,7 @@ class TestGenerationModels(unittest.TestCase):
torch_dtype,
max_new_tokens,
prefill_tolerance=prefill_tolerance,
output_tolerance=output_tolerance,
rouge_threshold=rouge_threshold,
long_context_tolerance=long_context_tolerance,
)