From 07ec07ad1fa59e0f07a4fcd1b1f324123c2e2bd4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 3 Dec 2024 01:58:25 -0800 Subject: [PATCH] Improve torch compile for fused moe (#2327) --- .../benchmark_torch_compile_fused_moe.py | 7 +++-- python/sglang/srt/layers/fused_moe_patch.py | 31 ++++++++++++------- .../srt/model_executor/cuda_graph_runner.py | 23 +++++++++----- .../sglang/srt/model_executor/model_runner.py | 2 +- test/srt/test_srt_engine.py | 2 +- test/srt/test_torch_compile_moe.py | 4 +-- 6 files changed, 45 insertions(+), 24 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1f54f9f9f..1bd6eec16 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -6,6 +6,7 @@ from torch.nn import functional as F from transformers import AutoConfig from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton +from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): @@ -64,7 +65,7 @@ def fused_topk_native( return topk_weights, topk_ids -@torch.compile +@torch.compile(dynamic=False) def fused_moe_torch( x, w1, @@ -88,7 +89,8 @@ def fused_moe_torch( w13_weights = w1[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = w2[topk_ids] - x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): print(f"benchmark {provider} with batch_size={batch_size}") torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) + set_torch_compile_config() num_tokens = batch_size num_experts = model_config["num_experts"] diff --git a/python/sglang/srt/layers/fused_moe_patch.py b/python/sglang/srt/layers/fused_moe_patch.py index 400ca03c4..baca25811 100644 --- a/python/sglang/srt/layers/fused_moe_patch.py +++ b/python/sglang/srt/layers/fused_moe_patch.py @@ -105,20 +105,29 @@ def fused_moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - assert custom_routing_function is None - topk_weights, topk_ids = select_experts_native( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - ) + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + x, router_logits, top_k, renormalize + ) + w13_weights = layer.w13_weight[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] - x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 84f6825c3..dd26a77ad 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -def _to_torch(model: torch.nn.Module, reverse: bool = False): +def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): else: # NOTE: Temporarily workaround MoE if "FusedMoE" in sub.__class__.__name__: - sub._forward_method = fused_moe_forward_native + if batch_size == 1: + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to skip it for now. + sub._forward_method = fused_moe_forward_native else: sub._forward_method = sub.forward_native setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse) + _to_torch(sub, reverse, batch_size) @contextmanager def patch_model( - model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, + enable_compile: bool, + batch_size: int, + tp_group: "GroupCoordinator", ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: if enable_compile: - _to_torch(model) + _to_torch(model, reverse=False, batch_size=batch_size) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. @@ -70,13 +76,15 @@ def patch_model( # even with ENABLE_INTRA_NODE_COMM=1. # tp_group.ca_comm = None yield torch.compile( - torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs" + torch.no_grad()(model.forward), + mode="max-autotune-no-cudagraphs", + dynamic=False, ) else: yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True) + _to_torch(model, reverse=True, batch_size=batch_size) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -237,6 +245,7 @@ class CudaGraphRunner: with patch_model( self.model_runner.model, bs in self.compile_bs, + bs, self.model_runner.tp_group, ) as forward: ( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 74a7d1fc5..fafb8783e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -622,7 +622,7 @@ class ModelRunner: tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s") + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index a985c8dda..7479b6468 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase): ) bench_args = BenchArgs(num_prompts=10) result = throughput_test(server_args=server_args, bench_args=bench_args) - self.assertGreater(result["total_throughput"], 3500) + self.assertGreater(result["total_throughput"], 3000) if __name__ == "__main__": diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py index 89d4ed6bd..fb78dd7f4 100644 --- a/test/srt/test_torch_compile_moe.py +++ b/test/srt/test_torch_compile_moe.py @@ -14,7 +14,7 @@ from sglang.test.test_utils import ( ) -class TestTorchCompile(unittest.TestCase): +class TestTorchCompileMoe(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST @@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"], + other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"], ) @classmethod