From a45a4b239df8d61b8658cb3cea38717621ca225f Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Sun, 27 Apr 2025 01:03:31 -0700
Subject: [PATCH] Split local attention test from fa3 test (#5774)

---
 test/srt/run_suite.py       |  3 +-
 test/srt/test_fa3.py        | 17 ---------
 test/srt/test_local_attn.py | 72 +++++++++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 18 deletions(-)
 create mode 100644 test/srt/test_local_attn.py

diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index f943a37f8..ab0dde8d7 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -30,6 +30,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
+        TestFile("test_fa3.py", 500),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
@@ -91,7 +92,7 @@ suites = {
         TestFile("test_verl_engine.py", 100),
     ],
     "per-commit-8-gpu": [
-        TestFile("test_fa3.py", 30),
+        TestFile("test_local_attn.py", 100),
     ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py
index d02c799fa..833bb3e6d 100644
--- a/test/srt/test_fa3.py
+++ b/test/srt/test_fa3.py
@@ -10,7 +10,6 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
-    DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
     DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -127,22 +126,6 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest):
         return DEFAULT_SERVER_ARGS
 
 
-class TestFlashAttention3LocalAttn(BaseFlashAttentionTest):
-    """Test FlashAttention3 with Model with local attention, e.g. Llama 4."""
-
-    accuracy_threshold = 0.70
-    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
-
-    @classmethod
-    def get_server_args(cls):
-        cloned_args = DEFAULT_SERVER_ARGS.copy()
-        # remove --enable-torch-compile from cloned_args since llama4 does not support it for now
-        cloned_args.remove("--enable-torch-compile")
-        # we cannot use scout's 10m context due to this bug: https://github.com/sgl-project/sglang/issues/5755
-        cloned_args.extend(["--tp", "4", "--context-length", "1000000"])
-        return cloned_args
-
-
 class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
     """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""
 
diff --git a/test/srt/test_local_attn.py b/test/srt/test_local_attn.py
new file mode 100644
index 000000000..392bb0f39
--- /dev/null
+++ b/test/srt/test_local_attn.py
@@ -0,0 +1,72 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import get_device_sm, kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
+class TestFlashAttention3LocalAttn(unittest.TestCase):
+    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
+    base_url = DEFAULT_URL_FOR_TEST
+    accuracy_threshold = 0.90
+
+    @classmethod
+    def get_server_args(cls):
+        return [
+            "--trust-remote-code",
+            "--cuda-graph-max-bs",
+            "2",
+            "--attention-backend",
+            "fa3",
+            "--tp",
+            "4",
+            "--context-length",
+            "1000000",
+        ]
+
+    @classmethod
+    def setUpClass(cls):
+        # disable deep gemm precompile to make launch server faster
+        # please don't do this if you want to make your inference workload faster
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=cls.get_server_args(),
+            env=os.environ,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=4,
+            num_questions=100,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+            data_path=None,
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"{metrics=}")
+
+        # Use the appropriate metric key based on the test class
+        metric_key = "accuracy"
+        self.assertGreater(metrics[metric_key], self.accuracy_threshold)
+
+
+if __name__ == "__main__":
+    unittest.main()