Optimize MoE topk with torch compile (#3236)

This commit is contained in:
Ke Bao
2025-02-01 01:36:50 +08:00
committed by GitHub
parent 7811bfdaa7
commit 1ebe1d6de5

View File

@@ -17,6 +17,8 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend
def fused_topk_native(
hidden_states: torch.Tensor,
@@ -74,6 +76,7 @@ def fused_topk(
# This is used by the Deepseek-V2 model
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -108,6 +111,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,