enable llama3.1-8B on xpu (#9434)
@@ -1,12 +1,20 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    is_xpu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_is_xpu = is_xpu()
 
 
 class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif _is_npu:
             return self.forward_npu
+        elif _is_xpu:
+            return self.forward_xpu
         else:
             return self.forward_native
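
To make the dispatch change above concrete, here is a minimal, self-contained sketch of the pattern these hunks extend; it is illustrative only (ToyOp and the inline XPU check are placeholders, not code from this commit): a CustomOp-style module resolves its backend-specific forward once from platform flags, and the new XPU branch routes Intel GPUs to forward_xpu.

import torch
from torch import nn


class ToyOp(nn.Module):
    """Sketch of CustomOp-style backend dispatch (illustrative, not the sglang class)."""

    def __init__(self):
        super().__init__()
        # Resolve the backend-specific forward once, at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        if torch.cuda.is_available():
            return self.forward_cuda
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # The branch this commit adds: Intel GPUs get forward_xpu.
            return self.forward_xpu
        else:
            return self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)

    # In a real op these would call device-specific kernels; the sketch reuses
    # the native formula so the example stays runnable everywhere.
    forward_cuda = forward_native
    forward_xpu = forward_native
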
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
 
-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
 
     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
         out = torch_npu.npu_swiglu(x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,16 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
@@ -242,7 +255,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
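
Both the SiluAndMul XPU path and the reshaped GeluAndMul above funnel into the same fused sgl_kernel calls; the GeluAndMul hunks simply move the old forward_cuda body into a shared _forward_impl that forward_cuda and forward_xpu both delegate to. As a reference for what those fused kernels compute, here is an illustrative, framework-free restatement of the native formulas (not code from this commit):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # The input packs gate and up projections along the last dim; the fused
    # kernel applies SiLU to the first half and multiplies by the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Matches GeluAndMul.forward_native in the hunk above.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


# Shape check: an input of width 2*d produces an output of width d, which is
# why the device paths allocate out = torch.empty(x.shape[:-1] + (d,), ...).
x = torch.randn(2, 8)
assert silu_and_mul_reference(x).shape == (2, 4)
assert gelu_and_mul_reference(x).shape == (2, 4)
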
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )
 
@@ -37,6 +38,7 @@ _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
     from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
@@ -327,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
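
The guards above rely on module-level platform flags computed once at import time from helpers in sglang.srt.utils. The real is_xpu() implementation is not shown in this diff; the following is a hedged stand-in for how such a check is commonly written on current PyTorch builds:

import torch


def is_xpu_guess() -> bool:
    # Hypothetical stand-in for sglang.srt.utils.is_xpu (not taken from this commit):
    # report whether this PyTorch build exposes a usable Intel XPU device.
    return hasattr(torch, "xpu") and torch.xpu.is_available()


_is_xpu = is_xpu_guess()
# Flags like this are evaluated once at import time, so every later
# "if not (... or _is_xpu)" fallback check above is just a cheap boolean test.
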
@@ -8,10 +8,11 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
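
The last hunk applies the same flag to imports themselves: the fused KV-cache transfer ops from sgl_kernel.kvcacheio are only imported on platforms that ship them, and other backends take a different code path. A generic, illustrative sketch of that guarded-fallback idea follows; the names and signature below are placeholders, not the sgl_kernel API:

import torch

_is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
_use_fused_kv_io = not _is_xpu  # mirrors the `if not (_is_npu or _is_xpu)` guard above


def copy_kv_pages_fallback(src: torch.Tensor, dst: torch.Tensor, indices: torch.Tensor) -> None:
    # Placeholder fallback: move the selected KV pages with plain tensor ops.
    # The real fused path lives in sgl_kernel.kvcacheio; its signatures are not
    # shown in this diff, so this stand-in only demonstrates the pattern.
    dst[indices] = src[indices].to(dst.device, non_blocking=True)
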