From c3eac1b010b3da3086457e40af555690da0787a6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 14 Nov 2024 01:30:24 -0800 Subject: [PATCH] Fix torch.compile for MoE (#2033) --- python/sglang/srt/layers/fused_moe/patch.py | 6 +- python/sglang/test/test_utils.py | 5 +- test/srt/run_suite.py | 1 + test/srt/test_data_parallelism.py | 2 +- test/srt/test_double_sparsity.py | 2 +- test/srt/test_eval_accuracy_mini.py | 2 +- test/srt/test_retract_decode.py | 2 +- test/srt/test_torch_compile.py | 6 +- test/srt/test_torch_compile_moe.py | 73 +++++++++++++++++++++ test/srt/test_triton_attention_backend.py | 2 +- 10 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 test/srt/test_torch_compile_moe.py diff --git a/python/sglang/srt/layers/fused_moe/patch.py b/python/sglang/srt/layers/fused_moe/patch.py index 65fcd7877..6e64c89aa 100644 --- a/python/sglang/srt/layers/fused_moe/patch.py +++ b/python/sglang/srt/layers/fused_moe/patch.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch from torch.nn import functional as F @@ -98,7 +98,9 @@ def fused_moe_forward_native( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: + assert custom_routing_function is None topk_weights, topk_ids = select_experts_native( hidden_states=x, router_logits=router_logits, @@ -114,4 +116,4 @@ def fused_moe_forward_native( x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) - return torch.einsum("tai,ta -> ti", expert_outs, topk_weights) + return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f136a4d1b..364ccefc5 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -28,8 +28,9 @@ from sglang.utils import get_exception_traceback DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8" DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" -DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" +DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B" +DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 @@ -740,7 +741,7 @@ def run_mmlu_test( try: metrics = run_eval(args) print(f"{metrics=}") - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) finally: pass diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 697fb8d21..7f343a15a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -27,6 +27,7 @@ suites = { "test_srt_engine.py", "test_srt_endpoint.py", "test_torch_compile.py", + "test_torch_compile_moe.py", "test_torchao.py", "test_triton_attention_kernels.py", "test_triton_attention_backend.py", diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py index 4efccd36c..393e44468 100644 --- a/test/srt/test_data_parallelism.py +++ b/test/srt/test_data_parallelism.py @@ -40,7 +40,7 @@ class TestDataParallelism(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) def test_update_weight(self): response = requests.post( diff --git a/test/srt/test_double_sparsity.py b/test/srt/test_double_sparsity.py index 14ee4de3c..1a35280a0 100644 --- a/test/srt/test_double_sparsity.py +++ b/test/srt/test_double_sparsity.py @@ -55,7 +55,7 @@ class TestDoubleSparsity(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) if __name__ == "__main__": diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index fa15c1181..a718feff7 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -35,7 +35,7 @@ class TestEvalAccuracyMini(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) if __name__ == "__main__": diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py index 20352e729..834c51f9d 100644 --- a/test/srt/test_retract_decode.py +++ b/test/srt/test_retract_decode.py @@ -34,7 +34,7 @@ class TestRetractDecode(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) if __name__ == "__main__": diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index f5f4b602e..ddb92a57f 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) def run_decode(self, max_new_tokens): response = requests.post( @@ -49,8 +49,8 @@ class TestTorchCompile(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": max_new_tokens, + "ignore_eos": True, }, - "ignore_eos": True, }, ) return response.json() @@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase): print(res["text"]) throughput = max_tokens / (tok - tic) print(f"Throughput: {throughput} tokens/s") - assert throughput >= 152 + self.assertGreaterEqual(throughput, 152) if __name__ == "__main__": diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py new file mode 100644 index 000000000..934ef3499 --- /dev/null +++ b/test/srt/test_torch_compile_moe.py @@ -0,0 +1,73 @@ +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestTorchCompile(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.50) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + "ignore_eos": True, + }, + }, + ) + return response.json() + + def test_throughput(self): + import time + + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(f"{res=}") + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + self.assertGreaterEqual(throughput, 290) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py index 2a6fe17bd..2687bb435 100644 --- a/test/srt/test_triton_attention_backend.py +++ b/test/srt/test_triton_attention_backend.py @@ -48,7 +48,7 @@ class TestTritonAttnBackend(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.65 + self.assertGreaterEqual(metrics["score"], 0.65) finally: kill_child_process(process.pid, include_self=True)