Optimize MoE topk with torch compile (#3236)
This commit is contained in:
@@ -17,6 +17,8 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
|
||||
def fused_topk_native(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -74,6 +76,7 @@ def fused_topk(
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 model
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@@ -108,6 +111,7 @@ def grouped_topk(
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def biased_grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user