use topk_softmax with sgl-kernel (#4439)

This commit is contained in:
Yineng Zhang
2025-03-14 15:59:06 -07:00
committed by GitHub
parent e73167ade3
commit ad1ae7f7cd
18 changed files with 48 additions and 35 deletions

View File

@@ -43,7 +43,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.0.5",
"sgl-kernel==0.0.5.post1",
"flashinfer_python==0.2.3",
"torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2",

View File

@@ -17,7 +17,9 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend
from sglang.srt.utils import get_compiler_backend, is_cuda
_is_cuda = is_cuda()
def fused_topk_native(
@@ -47,7 +49,10 @@ def fused_topk(
topk: int,
renormalize: bool,
):
from vllm import _custom_ops as ops
if _is_cuda:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -61,12 +66,20 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
if _is_cuda:
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
else:
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies
if renormalize: