From a68cb201dd5f4ae6155b324d22054bbb0de15fba Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 21 Sep 2024 10:25:20 +0800 Subject: [PATCH] Fix triton head num (#1482) --- .github/workflows/pr-test.yml | 6 +++ python/sglang/srt/layers/attention_backend.py | 4 +- python/sglang/test/test_utils.py | 1 + test/srt/test_mla.py | 44 +++++++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 test/srt/test_mla.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 8d4f839e8..a66ec9f71 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -233,6 +233,12 @@ jobs: run: | cd test/srt python3 test_moe_eval_accuracy_large.py + + - name: Evaluate MLA Accuracy (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_mla.TestMLA.test_mmlu finish: needs: [ diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index 73bdf512b..71dbfe0e3 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -346,7 +346,9 @@ class TritonAttnBackend(AttentionBackend): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd - self.num_head = model_runner.model_config.num_attention_heads + self.num_head = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): self.reduce_dtype = torch.float32 diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 172c2fcc3..a3de4b619 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -25,6 +25,7 @@ from sglang.utils import get_exception_traceback DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8" DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" +DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Meta-Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py new file mode 100644 index 000000000..0ec245a33 --- /dev/null +++ b/test/srt/test_mla.py @@ -0,0 +1,44 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestMLA(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--tp", "2", "--trust-remote-code"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main()