diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6fbfd3038..2bce60768 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -161,6 +161,12 @@ jobs: cd test/srt python3 test_moe_ep.py + - name: Test torch compile (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_mla_tp.py + performance-test-1-gpu-part-1: needs: filter if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && diff --git a/test/srt/test_mla_tp.py b/test/srt/test_mla_tp.py new file mode 100644 index 000000000..a5f2420cc --- /dev/null +++ b/test/srt/test_mla_tp.py @@ -0,0 +1,65 @@ +import unittest +from types import SimpleNamespace + +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDeepseekTP2(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "sgl-project/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreater(metrics["accuracy"], 0.62) + + def test_gsm8k_bs1(self): + # test torch compile accuracy for bs=1 + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=10, + max_new_tokens=512, + parallel=1, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreater(metrics["accuracy"], 0.62) + + +if __name__ == "__main__": + unittest.main()