Better unit tests for adding a new model (#1488)

Author: Lianmin Zheng
Date: 2024-09-22 01:50:37 -07:00
Committed by: GitHub
Parent: 441c22db8c
Commit: 167591e864
8 changed files with 157 additions and 126 deletions

View File: test_generation_models.py

@@ -1,3 +1,11 @@
"""
Usage:
To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,69 +21,55 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
MODELS = [
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
@dataclasses.dataclass
class ModelCase:
model_path: str
tp_size: int = 1
prefill_tolerance: float = 5e-2
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
# Popular models that run on CI
CI_MODELS = [
ModelCase("meta-llama/Meta-Llama-3.1-8B-Instruct"),
ModelCase("google/gemma-2-2b"),
]
# All other models
ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2-1.5B"),
]
TORCH_DTYPES = [torch.float16]
def lcs(X, Y):
m = len(X)
n = len(Y)
L = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
return L[m][n]
def calculate_rouge_l(output_strs_list1, output_strs_list2):
rouge_l_scores = []
for s1, s2 in zip(output_strs_list1, output_strs_list2):
lcs_len = lcs(s1, s2)
precision = lcs_len / len(s1) if len(s1) > 0 else 0
recall = lcs_len / len(s2) if len(s2) > 0 else 0
if precision + recall > 0:
fmeasure = (2 * precision * recall) / (precision + recall)
else:
fmeasure = 0.0
rouge_l_scores.append(fmeasure)
return rouge_l_scores
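For reference, calculate_rouge_l returns the F-measure of the character-level longest common subsequence for each pair of output strings. A minimal worked example, using the sglang.test.test_utils import that replaces the local definitions above:

from sglang.test.test_utils import calculate_rouge_l

# "abcde" vs. "abde": the LCS is "abde" (length 4), so
# precision = 4/5 = 0.8, recall = 4/4 = 1.0, F-measure = 1.6 / 1.8 ≈ 0.889
scores = calculate_rouge_l(["abcde"], ["abde"])
assert abs(scores[0] - 0.889) < 1e-3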
class TestGenerationModels(unittest.TestCase):
def assert_close_prefill_logits_and_output_strs(
def assert_close_logits_and_output_strs(
self,
prompts,
model_path,
tp_size,
torch_dtype,
max_new_tokens,
prefill_tolerance,
output_tolerance,
rouge_threshold,
long_context_tolerance,
prompts: List[str],
model_case: ModelCase,
torch_dtype: torch.dtype,
) -> None:
if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
prompts = prompts[:-1]
model_path = model_case.model_path
prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
model_case.prefill_tolerance,
model_case.decode_tolerance,
model_case.rouge_l_tolerance,
)
max_new_tokens = 32
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation=True
@@ -84,14 +78,14 @@ class TestGenerationModels(unittest.TestCase):
with SRTRunner(
model_path,
tp_size=tp_size,
tp_size=model_case.tp_size,
torch_dtype=torch_dtype,
is_generation=True,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)):
# input logprobs comparison
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
@@ -99,67 +93,56 @@ class TestGenerationModels(unittest.TestCase):
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# output logprobs comparison
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
# print(
# "output logprobs diff",
# [
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
# for j in range(max_new_tokens)
# ],
# )
print(
"output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(
abs(hf_logprobs - srt_logprobs) < output_tolerance
), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# output strings comparison
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(
hf_outputs.output_strs, srt_outputs.output_strs
)
print(f"rouge_l_scores={rouge_l_scores}")
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_threshold for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}"
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_prefill_logits_and_output_strs(self):
for (
model,
tp_size,
long_context_tolerance,
prefill_tolerance,
output_tolerance,
rouge_threshold,
) in MODELS:
def test_ci_models(self):
for model_case in CI_MODELS:
for torch_dtype in TORCH_DTYPES:
max_new_tokens = 32
self.assert_close_prefill_logits_and_output_strs(
DEFAULT_PROMPTS,
model,
tp_size,
torch_dtype,
max_new_tokens,
prefill_tolerance=prefill_tolerance,
output_tolerance=output_tolerance,
rouge_threshold=rouge_threshold,
long_context_tolerance=long_context_tolerance,
self.assert_close_logits_and_output_strs(
DEFAULT_PROMPTS, model_case, torch_dtype
)
def test_others(self):
for model_case in ALL_OTHER_MODELS:
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
self.assert_close_logits_and_output_strs(
DEFAULT_PROMPTS, model_case, torch.float16
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
mp.set_start_method("spawn")
unittest.main()
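With this refactor, adding a model for local testing follows the usage note at the top of the file: register one ModelCase and select it with ONLY_RUN. A rough sketch (the model path and tolerances below are illustrative, not part of this commit):

ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2-1.5B"),
ModelCase("my-org/my-new-model", tp_size=2, prefill_tolerance=1e-1, decode_tolerance=1e-1),
]
# Then run only that model:
#   ONLY_RUN=my-org/my-new-model python3 -m unittest test_generation_models.TestGenerationModels.test_others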

View File

@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
)
class TestReplaceWeights(unittest.TestCase):
class TestUpdateWeights(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -33,13 +33,7 @@ class TestReplaceWeights(unittest.TestCase):
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"n": 1,
},
"stream": False,
"return_logprob": False,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
@@ -64,7 +58,7 @@ class TestReplaceWeights(unittest.TestCase):
print(json.dumps(response.json()))
return ret
def test_replace_weights(self):
def test_update_weights(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
@@ -92,7 +86,7 @@ class TestReplaceWeights(unittest.TestCase):
updated_response = self.run_decode()
assert origin_response[:32] == updated_response[:32]
def test_replace_weights_unexist_model(self):
def test_update_weights_unexist_model(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()