[CI] Return output logprobs in unit test (#1361)
@@ -21,9 +21,9 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
-    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
-    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
 ]
 TORCH_DTYPES = [torch.float16]
 
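Note: each MODELS entry gains a per-model output_tolerance (4e-2, 5e-2, 4e-2) between prefill_tolerance and rouge_threshold. A hypothetical annotated layout, with field names read off the tuple unpacking in a later hunk; ModelCase itself is not part of this change, and the field comments are inferred rather than taken from the diff:

from typing import NamedTuple, Optional

class ModelCase(NamedTuple):
    model_path: str
    tp_size: int
    long_context_tolerance: Optional[float]  # None presumably skips the long-context check
    prefill_tolerance: float   # bound on |hf - srt| for input (prefill) logprobs
    output_tolerance: float    # new: bound on |hf - srt| for output logprobs
    rouge_threshold: float     # presumably a minimum ROUGE-L score for output strings

case = ModelCase("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1)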
@@ -70,6 +70,7 @@ class TestGenerationModels(unittest.TestCase):
         torch_dtype,
         max_new_tokens,
         prefill_tolerance,
+        output_tolerance,
         rouge_threshold,
         long_context_tolerance,
     ) -> None:
@@ -89,15 +90,37 @@ class TestGenerationModels(unittest.TestCase):
         srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
             # input logprobs comparison
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
 
-            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
-            if hf_logprobs.shape[0] <= 100:
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                 ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
 
+            # output logprobs comparison
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+            # print(
+            #     "output logprobs diff",
+            #     [
+            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
+            #         for j in range(max_new_tokens)
+            #     ],
+            # )
+            print(
+                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(
+                    abs(hf_logprobs - srt_logprobs) < output_tolerance
+                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
+
             # output strings comparison
             print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
             print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
             rouge_l_scores = calculate_rouge_l(
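The new output-logprobs check reuses the prefill pattern: an elementwise absolute-difference bound rather than torch.allclose, gated on short inputs (input_len <= 100). A standalone sketch of that check, with made-up tensor values for illustration:

import torch

# Elementwise tolerance check as used above; values are illustrative only.
hf_logprobs = torch.tensor([-0.10, -2.30, -5.01])
srt_logprobs = torch.tensor([-0.11, -2.33, -4.99])
output_tolerance = 4e-2

diff = abs(hf_logprobs - srt_logprobs)
print("output logprobs max_diff", torch.max(diff))
assert torch.all(
    diff < output_tolerance
), f"max diff {float(torch.max(diff)):.4f} exceeds tolerance {output_tolerance}"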
@@ -114,6 +137,7 @@ class TestGenerationModels(unittest.TestCase):
             tp_size,
             long_context_tolerance,
             prefill_tolerance,
+            output_tolerance,
             rouge_threshold,
         ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
@@ -125,6 +149,7 @@ class TestGenerationModels(unittest.TestCase):
                     torch_dtype,
                     max_new_tokens,
                     prefill_tolerance=prefill_tolerance,
+                    output_tolerance=output_tolerance,
                     rouge_threshold=rouge_threshold,
                     long_context_tolerance=long_context_tolerance,
                 )
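After the logprob checks, the test compares output strings with calculate_rouge_l, whose definition lies outside this diff. For intuition only, a generic LCS-based ROUGE-L F1 sketch (not sglang's implementation):

def lcs_len(a, b):
    # Classic dynamic-programming longest-common-subsequence length.
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return dp[-1][-1]

def rouge_l_f1(reference, candidate):
    # Token-level ROUGE-L F1: harmonic mean of LCS recall and precision.
    ref, cand = reference.split(), candidate.split()
    lcs = lcs_len(ref, cand)
    if lcs == 0:
        return 0.0
    recall, precision = lcs / len(ref), lcs / len(cand)
    return 2 * recall * precision / (recall + precision)

# Identical outputs score 1.0, so a rouge_threshold of 1 (as in MODELS above)
# presumably requires token-exact agreement between HF and SRT outputs.
assert rouge_l_f1("the cat sat", "the cat sat") == 1.0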