From 1ebe1d6de5e0ce082e0be059c222baf0c5ee340a Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 1 Feb 2025 01:36:50 +0800 Subject: [PATCH] Optimize MoE topk with torch compile (#3236) --- python/sglang/srt/layers/moe/topk.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 527a7d499..dc53e4445 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,6 +17,8 @@ from typing import Callable, Optional import torch import torch.nn.functional as F +from sglang.srt.utils import get_compiler_backend + def fused_topk_native( hidden_states: torch.Tensor, @@ -74,6 +76,7 @@ def fused_topk( # This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +111,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor,