Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)
Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
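The two new module-level flags are computed once at import time and drive the CPU/GPU aliasing in the last hunk of this diff. As a rough sketch of what such a probe can look like (assumption: PyTorch's private torch._C._cpu._is_amx_tile_supported() query; the real cpu_has_amx_support in sglang.srt.utils may be implemented differently):

import platform

import torch

def cpu_has_amx_support_sketch() -> bool:
    # Hypothetical probe: ask PyTorch whether the CPU supports AMX tiles.
    # Guarded because the symbol is private and absent on older builds.
    try:
        return platform.machine() in ("x86_64", "AMD64") and bool(
            torch._C._cpu._is_amx_tile_supported()
        )
    except AttributeError:
        return False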
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
 
-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
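Both the CUDA topk_softmax kernel and the new torch.ops.sgl_kernel.topk_softmax_cpu op compute the same routing math that fused_topk_torch_native expresses in plain PyTorch: softmax over experts, top-k selection, optional renormalization. A minimal reference sketch of that math (illustration only, not the kernel code):

import torch

def topk_softmax_reference(
    gating_output: torch.Tensor, topk: int, renormalize: bool
):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Rescale so each token's kept routing weights sum to 1.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

With gating_output of shape [num_tokens, num_experts] and topk=2, this returns [num_tokens, 2] weights and expert ids, the shapes the MoE layers consume downstream.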
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
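grouped_topk, which the comment above ties to the DeepSeek V2/V3/R1 series, first narrows routing to the best expert groups and only then runs top-k. A sketch of that group-then-topk selection, assuming experts are evenly partitioned into num_expert_group groups (simplified: shared experts, scaling, and padding handling omitted):

import torch

def grouped_topk_reference(
    gating_output: torch.Tensor,
    topk: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
):
    num_tokens, num_experts = gating_output.shape
    scores = torch.softmax(gating_output, dim=-1)
    # Rank each group by its strongest expert; keep the top `topk_group` groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Zero out experts in discarded groups, then take the per-token top-k.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    masked = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids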
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids
 
 
-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
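biased_grouped_topk adds a per-expert correction_bias (DeepSeek-V3-style aux-loss-free routing): the bias shifts which experts are selected, while the returned weights still come from the unbiased scores. A condensed sketch of that core idea, with sigmoid scoring and routed_scaling_factor applied as in the DeepSeek V3 recipe (group masking omitted; an illustration under those assumptions, not the kernel code):

import torch

def biased_topk_reference(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float = 1.0,
):
    # The bias steers selection only; weights are read from the raw scores.
    scores = gating_output.sigmoid()
    topk_ids = torch.topk(scores + correction_bias, topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids

The closing `if _is_cpu and _is_cpu_amx_available:` block then rebinds biased_grouped_topk, grouped_topk, and fused_topk_native once at import time, so downstream callers such as select_experts stay backend-agnostic instead of branching per call.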