diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index 8c662b5cc..ea3c06e6d 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -1,12 +1,20 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    is_xpu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_is_xpu = is_xpu()
 
 
 class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif _is_npu:
             return self.forward_npu
+        elif _is_xpu:
+            return self.forward_xpu
         else:
             return self.forward_native
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 4c7620669..37832a3f7 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
 
-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
 
     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
         out = torch_npu.npu_swiglu(x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,16 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
@@ -242,7 +255,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 81ec3693a..5d941a489 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )
 
@@ -37,6 +38,7 @@ _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
     from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
@@ -327,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index 9b9553238..15b5efe5a 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -8,10 +8,11 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,