diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 507590025..03406ef86 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -90,11 +90,11 @@ jobs: - name: MLA TEST timeout-minutes: 20 run: | - docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py TestMLA + docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py finish: needs: [ - accuracy-test-1-gpu-amd + accuracy-test-1-gpu-amd, mla-test-1-gpu-amd ] runs-on: ubuntu-latest steps: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 754fe9a79..ebab2bf68 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -24,6 +24,7 @@ suites = { "test_gguf.py", "test_input_embeddings.py", "test_mla.py", + "test_mla_deepseek_v3.py", "test_mla_flashinfer.py", "test_mla_fp8.py", "test_json_constrained.py", diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index b2a831f99..b1f9d090d 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -1,11 +1,7 @@ import unittest from types import SimpleNamespace -import requests -import torch - from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -56,101 +52,5 @@ class TestMLA(unittest.TestCase): self.assertGreater(metrics["score"], 0.8) -class TestDeepseekV3(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = "lmsys/sglang-ci-dsv3-test" - cls.base_url = DEFAULT_URL_FOR_TEST - other_args = ["--trust-remote-code"] - if torch.cuda.is_available() and torch.version.cuda: - other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"]) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.62) - - -class TestDeepseekV3MTP(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = "lmsys/sglang-ci-dsv3-test" - cls.base_url = DEFAULT_URL_FOR_TEST - other_args = ["--trust-remote-code"] - if torch.cuda.is_available() and torch.version.cuda: - other_args.extend( - [ - "--cuda-graph-max-bs", - "2", - "--disable-radix", - "--enable-torch-compile", - "--torch-compile-max-bs", - "1", - "--speculative-algorithm", - "EAGLE", - "--speculative-draft", - "lmsys/sglang-ci-dsv3-test-NextN", - "--speculative-num-steps", - "2", - "--speculative-eagle-topk", - "4", - "--speculative-num-draft-tokens", - "4", - ] - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 2.5) - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py new file mode 100644 index 000000000..ba43c2ba1 --- /dev/null +++ b/test/srt/test_mla_deepseek_v3.py @@ -0,0 +1,113 @@ +import unittest +from types import SimpleNamespace + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDeepseekV3(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"]) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + +class TestDeepseekV3MTP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--disable-radix", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "2", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "4", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + + +if __name__ == "__main__": + unittest.main()