[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
@@ -27,8 +27,13 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
|
||||
from sglang.test.runners import (
|
||||
DEFAULT_PROMPTS,
|
||||
HFRunner,
|
||||
SRTRunner,
|
||||
check_close_model_outputs,
|
||||
)
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -39,6 +44,7 @@ class ModelCase:
|
||||
decode_tolerance: float = 5e-2
|
||||
rouge_l_tolerance: float = 1
|
||||
skip_long_prompt: bool = False
|
||||
trust_remote_code: bool = False
|
||||
|
||||
|
||||
# Popular models that run on the CI
|
||||
@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [
|
||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
||||
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
|
||||
ModelCase("THUDM/glm-4-9b-chat"),
|
||||
ModelCase(
|
||||
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
|
||||
),
|
||||
ModelCase("openai-community/gpt2"),
|
||||
ModelCase("microsoft/Phi-3-small-8k-instruct"),
|
||||
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
||||
@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase):
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase):
|
||||
tp_size=model_case.tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
for i in range(len(prompts)):
|
||||
# Compare input logprobs
|
||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||
input_len = hf_logprobs.shape[0]
|
||||
print(
|
||||
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||
)
|
||||
if input_len <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
||||
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
|
||||
f"prefill_tolerance={prefill_tolerance}."
|
||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||
)
|
||||
|
||||
# Compare output logprobs
|
||||
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
|
||||
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
|
||||
|
||||
print(
|
||||
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||
)
|
||||
if input_len <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
||||
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
|
||||
f"decode_tolerance={decode_tolerance}."
|
||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||
)
|
||||
|
||||
# Compare output strings
|
||||
print(f"{hf_outputs.output_strs=}")
|
||||
print(f"{srt_outputs.output_strs=}")
|
||||
rouge_l_scores = calculate_rouge_l(
|
||||
hf_outputs.output_strs, srt_outputs.output_strs
|
||||
check_close_model_outputs(
|
||||
hf_outputs=hf_outputs,
|
||||
srt_outputs=srt_outputs,
|
||||
prefill_tolerance=model_case.prefill_tolerance,
|
||||
decode_tolerance=model_case.decode_tolerance,
|
||||
rouge_l_tolerance=model_case.rouge_l_tolerance,
|
||||
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||
)
|
||||
print(f"{rouge_l_scores=}")
|
||||
assert all(
|
||||
score >= rouge_l_tolerance for score in rouge_l_scores
|
||||
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
|
||||
|
||||
def test_ci_models(self):
|
||||
for model_case in CI_MODELS:
|
||||
|
||||
Reference in New Issue
Block a user