fix: Fix returned prefill logits and add output str test (#1046)

This commit is contained in:
Ying Sheng
2024-08-11 23:13:45 -07:00
committed by GitHub
parent fb1f28cbbb
commit 32f6144323
3 changed files with 33 additions and 14 deletions

View File

@@ -21,23 +21,25 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
MODELS = [
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
("google/gemma-2-2b", 1),
]
TORCH_DTYPES = [torch.float16]
class TestCausalModels(unittest.TestCase):
class TestGenerationModels(unittest.TestCase):
def assert_close_prefill_logits(
def assert_close_prefill_logits_and_output_strs(
self,
prompts,
model_path,
tp_size,
torch_dtype,
max_new_tokens,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation_model=True
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
with SRTRunner(
model_path,
@@ -45,7 +47,7 @@ class TestCausalModels(unittest.TestCase):
torch_dtype=torch_dtype,
is_generation_model=True,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts)
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)):
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
@@ -56,11 +58,18 @@ class TestCausalModels(unittest.TestCase):
abs(hf_logprobs - srt_logprobs) < tolerance
), f"prefill logprobs not all close"
assert hf_outputs.output_strs == srt_outputs.output_strs
def test_prefill_logits(self):
for model, tp_size in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype
max_new_tokens = 8
self.assert_close_prefill_logits_and_output_strs(
DEFAULT_PROMPTS,
model,
tp_size,
torch_dtype,
max_new_tokens,
)