diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index ab0dde8d7..fa99c1ff5 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -30,7 +30,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
-        TestFile("test_fa3.py", 500),
+        TestFile("test_fa3.py", 400),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
@@ -92,7 +92,7 @@ suites = {
         TestFile("test_verl_engine.py", 100),
     ],
     "per-commit-8-gpu": [
-        TestFile("test_local_attn.py", 100),
+        TestFile("test_local_attn.py", 250),
     ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py
index 833bb3e6d..1dd3fc65c 100644
--- a/test/srt/test_fa3.py
+++ b/test/srt/test_fa3.py
@@ -3,7 +3,6 @@ import unittest
 from types import SimpleNamespace
 
 import requests
-import torch
 
 from sglang.srt.utils import get_device_sm, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
     popen_launch_server,
 )
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
 # Default server arguments shared across all tests
 DEFAULT_SERVER_ARGS = [
     "--trust-remote-code",
-    "--enable-torch-compile",
     "--cuda-graph-max-bs",
-    "2",
+    "4",
     "--attention-backend",
     "fa3",
 ]
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
 
 
 @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
-class BaseFlashAttentionTest(unittest.TestCase):
+class BaseFlashAttentionTest(CustomTestCase):
     """Base class for testing FlashAttention3."""
 
     model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
     def setUpClass(cls):
         # disable deep gemm precompile to make launch server faster
         # please don't do this if you want to make your inference workload faster
-        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False"
+        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
+        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
-            env=os.environ,
         )
 
     @classmethod
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
         kill_process_tree(cls.process.pid)
 
     def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
         args = SimpleNamespace(
             num_shots=4,
             num_questions=100,
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
             data_path=GSM_DATASET_PATH,
         )
         metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
+        print(f"{metrics=}")
 
         # Use the appropriate metric key based on the test class
         metric_key = "accuracy"
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
         return args
 
 
-class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
-    """Test FlashAttention3 with speculative decode enabled, topk > 1"""
-
-    model = DEFAULT_MODEL_NAME_FOR_TEST
-
-    @classmethod
-    def get_server_args(cls):
-        args = super().get_server_args()
-        args.extend(
-            [
-                "--cuda-graph-max-bs",
-                "2",
-                "--speculative-algorithm",
-                "EAGLE3",
-                "--speculative-draft",
-                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
-                "--speculative-num-steps",
-                "5",
-                "--speculative-eagle-topk",
-                "4",
-                "--speculative-num-draft-tokens",
-                "8",
-                "--dtype",
-                "float16",
-            ]
-        )
-        return args
-
-    def test_gsm8k(self):
-        """
-        Override the test_gsm8k to further test for average speculative accept length.
-        """
-        requests.get(self.base_url + "/flush_cache")
-
-        args = SimpleNamespace(
-            num_shots=5,
-            data_path=GSM_DATASET_PATH,
-            num_questions=200,
-            max_new_tokens=512,
-            parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.base_url.split(":")[-1]),
-        )
-        metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
-
-        self.assertGreater(metrics["accuracy"], 0.60)
-
-        server_info = requests.get(self.base_url + "/get_server_info")
-        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
-        print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 1.8)
-
-
 class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model
     and its nextN model"""
diff --git a/test/srt/test_local_attn.py b/test/srt/test_local_attn.py
index 392bb0f39..923ffdd5e 100644
--- a/test/srt/test_local_attn.py
+++ b/test/srt/test_local_attn.py
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
     popen_launch_server,
 )
 
 
 @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
-class TestFlashAttention3LocalAttn(unittest.TestCase):
+class TestFlashAttention3LocalAttn(CustomTestCase):
     model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
     base_url = DEFAULT_URL_FOR_TEST
     accuracy_threshold = 0.90
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
     @classmethod
     def get_server_args(cls):
         return [
-            "--trust-remote-code",
             "--cuda-graph-max-bs",
             "2",
             "--attention-backend",
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        # disable deep gemm precompile to make launch server faster
-        # please don't do this if you want to make your inference workload faster
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
         kill_process_tree(cls.process.pid)
 
     def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
         args = SimpleNamespace(
             num_shots=4,
             num_questions=100,