use topk_softmax with sgl-kernel (#4439)
This commit is contained in:
@@ -43,7 +43,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.0.5",
|
||||
"sgl-kernel==0.0.5.post1",
|
||||
"flashinfer_python==0.2.3",
|
||||
"torch==2.5.1",
|
||||
"vllm>=0.6.4.post1,<=0.7.2",
|
||||
|
||||
@@ -17,7 +17,9 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.srt.utils import get_compiler_backend, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
|
||||
def fused_topk_native(
|
||||
@@ -47,7 +49,10 @@ def fused_topk(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
from vllm import _custom_ops as ops
|
||||
if _is_cuda:
|
||||
from sgl_kernel import topk_softmax
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -61,12 +66,20 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
if _is_cuda:
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
else:
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
del token_expert_indicies
|
||||
|
||||
if renormalize:
|
||||
|
||||
Reference in New Issue
Block a user