Optimize MoE topk with torch compile (#3236)
This commit is contained in:
@@ -17,6 +17,8 @@ from typing import Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
|
|
||||||
def fused_topk_native(
|
def fused_topk_native(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -74,6 +76,7 @@ def fused_topk(
|
|||||||
|
|
||||||
|
|
||||||
# This is used by the Deepseek-V2 model
|
# This is used by the Deepseek-V2 model
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def grouped_topk(
|
def grouped_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@@ -108,6 +111,7 @@ def grouped_topk(
|
|||||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def biased_grouped_topk(
|
def biased_grouped_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user