Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
YanbingJiang authored 2025-06-18 13:11:50 +08:00, committed by GitHub
parent e56685ac1b
commit 094c116f7d
23 changed files with 270 additions and 56 deletions


@@ -20,12 +20,21 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -122,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
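
For reference, a minimal usage sketch of the new CPU path added in the hunk above. The module path and the RMSNorm(hidden_size, eps) constructor are assumptions, not shown in this diff: when cpu_has_amx_support() is true, forward_cpu dispatches to the torch.ops.sgl_kernel CPU kernels; otherwise it falls back to forward_native.

# Hedged sketch, not part of the commit; module path and constructor are assumed.
import torch

from sglang.srt.layers.layernorm import RMSNorm  # assumed module path

norm = RMSNorm(4096, eps=1e-6)   # assumed RMSNorm(hidden_size, eps) signature
x = torch.randn(8, 4096)         # (num_tokens, hidden_size)
residual = torch.randn(8, 4096)

out = norm.forward_cpu(x)                      # rmsnorm_cpu kernel (or forward_native fallback)
out, residual = norm.forward_cpu(x, residual)  # fused_add_rmsnorm_cpu updates x and residual in place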
@@ -188,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip or _is_npu):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
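
The last hunk widens the module-level availability check so that AMX-capable CPUs are treated as having native sgl-kernel layernorm support. A roughly equivalent predicate, written as a standalone helper for illustration only (the helper name is hypothetical; the real code inlines the condition):

from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

def _sgl_kernel_layernorm_available() -> bool:
    # Hypothetical helper mirroring the updated check: native kernels exist on
    # CUDA, HIP, NPU, or an AMX-capable CPU; anywhere else the module logs the
    # fallback message and uses other kernel libraries.
    return is_cuda() or is_hip() or is_npu() or (is_cpu() and cpu_has_amx_support())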