diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index a02d6e2d5..c061b6226 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -31,3 +31,6 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
 # For lmms_evals evaluating MMMU
 git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
 pip install -e lmms-eval/
+
+# Install FlashMLA for attention backend tests
+pip install git+https://github.com/deepseek-ai/FlashMLA.git
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 4afb76e00..c3f494a3c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -51,6 +51,7 @@ suites = {
         TestFile("test_mla_int8_deepseek_v3.py", 389),
         TestFile("test_mla_flashinfer.py", 395),
         TestFile("test_mla_fp8.py", 153),
+        TestFile("test_flash_mla_attention_backend.py", 300),
         TestFile("test_no_chunked_prefill.py", 108),
         TestFile("test_no_overlap_scheduler.py", 216),
         TestFile("test_openai_server.py", 149),
diff --git a/test/srt/test_flash_mla_attention_backend.py b/test/srt/test_flash_mla_attention_backend.py
new file mode 100644
index 000000000..8d895d2eb
--- /dev/null
+++ b/test/srt/test_flash_mla_attention_backend.py
@@ -0,0 +1,64 @@
+"""
+Usage:
+python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_mmlu
+"""
+
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    is_in_ci,
+    popen_launch_server,
+    run_bench_one_batch,
+)
+
+
+class TestFlashMLAAttnBackend(unittest.TestCase):
+    def test_latency(self):
+        output_throughput = run_bench_one_batch(
+            DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+            [
+                "--attention-backend",
+                "flashmla",
+                "--enable-torch-compile",
+                "--cuda-graph-max-bs",
+                "16",
+                "--trust-remote-code",
+            ],
+        )
+
+        if is_in_ci():
+            self.assertGreater(output_throughput, 153)
+
+    def test_mmlu(self):
+        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--attention-backend", "flashmla", "--trust-remote-code"],
+        )
+
+        try:
+            args = SimpleNamespace(
+                base_url=base_url,
+                model=model,
+                eval_name="mmlu",
+                num_examples=64,
+                num_threads=32,
+            )
+
+            metrics = run_eval(args)
+            self.assertGreaterEqual(metrics["score"], 0.2)
+        finally:
+            kill_process_tree(process.pid)
+
+
+if __name__ == "__main__":
+    unittest.main()
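
Taken together, the patch wires FlashMLA into CI end to end: ci_install_dependency.sh installs the kernels, run_suite.py registers the new test file (the 300 next to the filename follows the per-test time-estimate convention of the surrounding TestFile entries), and the test module covers both a single-batch latency check and an MMLU smoke eval. For local verification outside CI, the same pieces compose into a short manual sequence; a minimal sketch, assuming the FlashMLA build succeeds against the local CUDA toolchain and commands are issued from the repository root:

    # Install the FlashMLA kernels from the same source the CI script uses.
    pip install git+https://github.com/deepseek-ai/FlashMLA.git

    # Run the new tests directly, following the module docstring.
    cd test/srt
    python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_mmlu
    python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_latency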
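
test_mmlu drives a full server launch through popen_launch_server, forwarding other_args to the server CLI. The equivalent manual launch for interactive debugging would look roughly like the sketch below; <mla-model> is a placeholder for whatever MLA-capable checkpoint DEFAULT_MLA_MODEL_NAME_FOR_TEST resolves to in sglang's test utilities, and the remaining flags mirror the ones the test passes.

    # Sketch only: launch the server manually with the FlashMLA backend.
    python3 -m sglang.launch_server \
        --model-path <mla-model> \
        --attention-backend flashmla \
        --trust-remote-code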