diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index ccff4524f..87dad3ed0 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): return reqs +@torch.inference_mode() def extend(reqs, model_runner): batch = ScheduleBatch.init_new( reqs=reqs, @@ -235,6 +236,7 @@ def extend(reqs, model_runner): return next_token_ids, logits_output.next_token_logits, batch +@torch.inference_mode() def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) model_worker_batch = batch.get_model_worker_batch() @@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner): return next_token_ids, logits_output.next_token_logits -@torch.inference_mode() def correctness_test( server_args, port_args, @@ -287,7 +288,6 @@ def correctness_test( rank_print(tokenizer.decode(output_ids[i]), "\n") -@torch.inference_mode() def latency_test_run_once( run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len ): diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index fadc6dd50..802f40d7d 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -42,13 +42,13 @@ class ModelCase: rouge_l_tolerance: float = 1 -# Popular models that run on CI +# Popular models that run on the CI CI_MODELS = [ ModelCase("meta-llama/Llama-3.1-8B-Instruct"), ModelCase("google/gemma-2-2b"), ] -# All other models +# All other models that do not run on the CI ALL_OTHER_MODELS = [ ModelCase("Qwen/Qwen2-1.5B"), ModelCase("Qwen/Qwen2.5-14B-Instruct"), @@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16] class TestGenerationModels(unittest.TestCase): + @classmethod + def setUpClass(cls): + mp.set_start_method("spawn") + def assert_close_logits_and_output_strs( self, prompts: List[str], @@ -140,16 +144,21 @@ class 
TestGenerationModels(unittest.TestCase): return for model_case in ALL_OTHER_MODELS: + # Only run the specified model if ( "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.model_path ): continue - self.assert_close_logits_and_output_strs( - DEFAULT_PROMPTS, model_case, torch.float16 - ) + + # Skip long prompts for models that do not have a long context + prompts = DEFAULT_PROMPTS + if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]: + prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000] + + # Assert the logits and output strs are close + self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16) if __name__ == "__main__": - mp.set_start_method("spawn") unittest.main()