Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
Author: YanbingJiang
Date: 2025-06-18 13:11:50 +08:00
Committed by: GitHub
Parent: e56685ac1b
Commit: 094c116f7d
23 changed files with 270 additions and 56 deletions


@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
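
The two new module-level flags are what gates the CPU kernel dispatch at the tail of this diff. For reference, a minimal sketch of how an AMX probe can look in recent PyTorch; this is an assumption, since the body of cpu_has_amx_support in sglang.srt.utils is not shown on this page:

import torch

def cpu_has_amx_support_sketch() -> bool:
    # torch._C._cpu._is_amx_tile_supported() exists in recent PyTorch builds;
    # guard with getattr so the probe degrades to False on older versions.
    cpu_probe = getattr(torch._C, "_cpu", None)
    if cpu_probe is None:
        return False
    return bool(getattr(cpu_probe, "_is_amx_tile_supported", lambda: False)())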
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
 
-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
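
fused_topk_cpu is a thin wrapper that forwards directly to the sgl-kernel custom op. A hedged smoke test of that op follows; it assumes a CPU build of sgl-kernel is installed so the op is registered, and that it returns (topk_weights, topk_ids) like the native path it replaces:

import torch

num_tokens, hidden_size, num_experts, topk = 4, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size)   # forwarded for shape/device checks
gating_output = torch.randn(num_tokens, num_experts)   # raw router logits

topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=topk,
    renormalize=True,  # rescale the selected top-k weights to sum to 1
)
assert topk_ids.shape == (num_tokens, topk)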
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
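
grouped_topk_cpu likewise defers to a fused kernel, so the routing math is not visible in this diff. As a reading aid, here is an illustrative pure-PyTorch sketch of DeepSeek-style grouped top-k (what the grouped_topk variants compute), simplified to ignore shared experts, padding, and the routed scaling factor:

import torch

def grouped_topk_reference(gating_output, topk, renormalize, num_expert_group, topk_group):
    # Softmax over all experts, then restrict selection to the best expert groups.
    scores = torch.softmax(gating_output, dim=-1)  # [num_tokens, num_experts]
    num_tokens, num_experts = scores.shape
    # Score each group by its strongest expert and keep the top `topk_group` groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Zero out experts outside the selected groups, then take the per-token top-k.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    topk_weights, topk_ids = torch.topk(
        scores.masked_fill(score_mask == 0, 0.0), k=topk, dim=-1
    )
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids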
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids
 
 
-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
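
With the aliasing block in place, call sites such as select_experts keep using the original names and the backend is chosen once at import time. A hedged usage sketch; the module path is an assumption, since the changed file's path is not shown on this page:

import torch

from sglang.srt.layers.moe.topk import fused_topk_native, grouped_topk  # assumed path

hidden_states = torch.randn(4, 16)
gating_output = torch.randn(4, 8)

# Resolves to fused_topk_cpu on an AMX-capable CPU with a CPU sgl-kernel build,
# and to fused_topk_torch_native otherwise.
topk_weights, topk_ids = fused_topk_native(
    hidden_states, gating_output, topk=2, renormalize=True
)

# Same pattern: grouped_topk is grouped_topk_cpu or grouped_topk_gpu.
topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)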