# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Usage:

To test a specific model locally:
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
"""
import dataclasses
import multiprocessing as mp
import os
import random
import unittest
from typing import List

import torch

from sglang.test.runners import (
    DEFAULT_PROMPTS,
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci


@dataclasses.dataclass
class ModelCase:
    """One model configuration to compare between the HF and SRT runners.

    The tolerance fields are forwarded verbatim to
    ``check_close_model_outputs`` when comparing the two runners' outputs.
    """

    model_path: str  # HuggingFace hub path of the model
    tp_size: int = 1  # tensor-parallel size passed to SRTRunner
    prefill_tolerance: float = 5e-2  # forwarded to check_close_model_outputs
    decode_tolerance: float = 5e-2  # forwarded to check_close_model_outputs
    rouge_l_tolerance: float = 1  # forwarded to check_close_model_outputs
    skip_long_prompt: bool = False  # drop prompts of >=1000 chars (short-context models)
    trust_remote_code: bool = False  # passed to both HFRunner and SRTRunner
# Popular models that run on the CI
CI_MODELS = [
    ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
    # TODO: Gemma is broken by the bug introduced in the latest transformers version,
    # we should restore once it is fixed:
    # https://github.com/huggingface/transformers/issues/39711
    # ModelCase("google/gemma-2-2b"),
]
# The complete set of models used to test sglang's generation models.
ALL_MODELS = [
    *CI_MODELS,
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
    ModelCase(
        "THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
    ),
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/phi-1_5", trust_remote_code=True),
    ModelCase("adept/persimmon-8b-chat"),
    ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
    ModelCase(
        "microsoft/Phi-3.5-MoE-instruct",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
]
# Every model case is exercised once per dtype in this list.
TORCH_DTYPES = [torch.float16]
class TestGenerationModels(CustomTestCase):
    """Compare HF (reference) and SRT outputs for each configured model."""

    @classmethod
    def setUpClass(cls):
        # Runners launch model workers in subprocesses; force "spawn" so the
        # start method is consistent regardless of platform defaults.
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
        """Run *prompts* through both runners and assert outputs are close.

        Tolerances come from *model_case* and are enforced by
        ``check_close_model_outputs``.
        """
        model_path = model_case.model_path
        max_new_tokens = 32

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        check_close_model_outputs(
            hf_outputs=hf_outputs,
            srt_outputs=srt_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    def _run_model_case(self, model_case: ModelCase) -> None:
        """Run one model case across all dtypes, honoring skip_long_prompt."""
        for torch_dtype in TORCH_DTYPES:
            prompts = DEFAULT_PROMPTS
            # Skip long prompts for models that do not have a long context.
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
            # Assert the logits and output strs are close.
            self.assert_close_logits_and_output_strs(prompts, model_case, torch_dtype)

    @unittest.skipIf(not is_in_ci(), "Local test should run all models")
    def test_ci_models(self):
        for model_case in CI_MODELS:
            self._run_model_case(model_case)

    @unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
    def test_all_models(self):
        for model_case in ALL_MODELS:
            # ONLY_RUN=<model_path> restricts a local run to a single model.
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue
            self._run_model_case(model_case)
# Standard unittest entry point for running this file directly.
if __name__ == "__main__":
    unittest.main()