Init attention backend for Intel XPU (#10656)

Author: Meng, Hengyu
Date: 2025-10-21 11:41:28 +08:00
Co-authored-by: guangyey <guangye.yu@intel.com>
Co-authored-by: DiweiSun <105627594+DiweiSun@users.noreply.github.com>

parent fb6cc7b000
commit b113c72e7a
18 changed files with 1210 additions and 26 deletions


@@ -1,5 +1,3 @@
-# xpu is not enabled in public vllm and torch whl,
-# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
 [build-system]
 requires = ["setuptools>=61.0", "wheel"]
 build-backend = "setuptools.build_meta"
@@ -17,6 +15,10 @@ classifiers = [
 ]
 dependencies = [
+  "torch==2.8.0",
+  "torchaudio==2.8.0",
+  "torchvision",
+  "sgl-kernel @ git+https://github.com/sgl-project/sgl-kernel-xpu.git",
   "IPython",
   "aiohttp",
   "anthropic>=0.20.0",
@@ -61,7 +63,7 @@ dependencies = [
   "transformers==4.57.1",
   "uvicorn",
   "uvloop",
-  "xgrammar==0.1.25",
+  # "xgrammar==0.1.24",  # xgrammar depends on CUDA PyTorch and Triton only
   "grpcio==1.75.1",  # keep it aligned with compile_proto.py
   "grpcio-tools==1.75.1",  # keep it aligned with compile_proto.py
   "grpcio-reflection==1.75.1",  # required by srt/entrypoints/grpc_server.py


@@ -272,7 +272,7 @@ def prepare_synthetic_inputs_for_latency_test(
 def extend(reqs, model_runner):
     # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
     dummy_tree_cache = SimpleNamespace(
-        page_size=1,
+        page_size=model_runner.server_args.page_size,
         device=model_runner.device,
         token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
     )
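Hardcoding page_size=1 no longer works once the runner is launched with the Intel XPU backend, which requires larger pages (see the server_args change below), so the benchmark now mirrors the live setting. A minimal sketch of the duck-typed stand-in, with illustrative values:

from types import SimpleNamespace

# SimpleNamespace supplies only the attributes the allocation path reads.
dummy_tree_cache = SimpleNamespace(
    page_size=64,                     # now mirrors server_args.page_size
    device="xpu",                     # stand-in for model_runner.device
    token_to_kv_pool_allocator=None,  # stand-in for the real allocator
)
assert dummy_tree_cache.page_size == 64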


@@ -50,11 +50,13 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     is_shm_available,
+    is_xpu,
     supports_custom_op,
 )
 
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 _supports_custom_op = supports_custom_op()

@@ -694,7 +696,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not _supports_custom_op:
+        if _is_npu or _is_xpu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(

@@ -1298,7 +1300,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not _is_npu,
+        use_pynccl=not (_is_npu or _is_xpu),
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_torch_symm_mem=use_symm_mem_allreduce,
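Both changes route XPU down the same conservative path as NPU: the registered custom all-gather op and pynccl are CUDA-oriented, so XPU uses the plain implementation. A minimal sketch of the gating, with stand-in flag values:

# Stand-in flags; the real values come from sglang.srt.utils at import time.
_is_npu, _is_xpu, _supports_custom_op = False, True, True

# XPU now takes the eager branch (self._all_gather_into_tensor) ...
take_eager_all_gather = _is_npu or _is_xpu or not _supports_custom_op
# ... and pynccl stays disabled for both NPU and XPU.
use_pynccl = not (_is_npu or _is_xpu)

assert take_eager_all_gather and not use_pynccl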


@@ -217,3 +217,10 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
     )
     return full_attn_backend
+
+
+@register_attention_backend("intel_xpu")
+def create_intel_xpu_backend(runner):
+    from sglang.srt.layers.attention.xpu_backend import XPUAttentionBackend
+
+    return XPUAttentionBackend(runner)
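For context, a minimal sketch of the decorator registry this hooks into, assuming register_attention_backend stores factories in a name-keyed dict that ModelRunner later consults via --attention-backend (the dict name and lookup here are illustrative):

ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator

@register_attention_backend("intel_xpu")
def create_intel_xpu_backend(runner):
    # The real factory imports XPUAttentionBackend lazily and returns it.
    return ("XPUAttentionBackend", runner)

factory = ATTENTION_BACKENDS["intel_xpu"]  # looked up by backend name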


@@ -12,6 +12,8 @@ import triton
 import triton.language as tl
 from einops import rearrange
 
+from sglang.srt.utils import device_context
+
 
 def rms_norm_ref(
     x,

@@ -157,7 +159,7 @@ def _layer_norm_fwd(
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
     grid = (M, ngroups)
-    with torch.get_device_module(x.device).device(x.device.index):
+    with device_context(x.device):
         _layer_norm_fwd_1pass_kernel[grid](
             x,
             out,

(One file's diff is suppressed here because it is too large.)


@@ -42,7 +42,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
 
-if _is_cuda:
+if _is_cuda or _is_xpu:
     # if _is_flashinfer_available:
     #     from flashinfer.norm import fused_add_rmsnorm
     # else:
@@ -52,13 +52,6 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
-elif _is_xpu:
-    from sgl_kernel import (
-        fused_add_rmsnorm,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        rmsnorm,
-    )
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm


@@ -39,10 +39,11 @@ if TYPE_CHECKING:
         CombineInput,
     )
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_xpu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
 
 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
@@ -58,8 +59,12 @@ elif _is_hip:
     )
 
     warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+elif _is_xpu:
+    from sgl_kernel import awq_dequantize
+
+    warnings.warn(f"XPU does not support fused_marlin_moe currently.")
 else:
-    warnings.warn(f"Only CUDA and HIP support AWQ currently.")
+    warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
 
 logger = logging.getLogger(__name__)


@@ -115,7 +115,7 @@ class RotaryEmbedding(CustomOp):
         if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
-            and not _is_xpu
+            and not (_is_xpu)
         ):
             from vllm._custom_ops import rotary_embedding

@@ -302,6 +302,7 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: make a wrapper, and XPU will implement this kernel later.
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
         return self.forward_native(positions, query, key, offsets)
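The added line protects the fallback: the cos/sin cache is built at init time and may sit on a different device than the activations, so it is moved to the query's device before forward_native indexes it. A tiny sketch of the pattern (devices illustrative):

import torch

cos_sin_cache = torch.randn(4096, 128)          # built on the default device
query = torch.randn(2, 128)                     # imagine device="xpu" here
cos_sin_cache = cos_sin_cache.to(query.device)  # cheap no-op if already there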


@@ -142,6 +142,7 @@ from sglang.srt.utils import (
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
+    xpu_has_xmx_support,
 )
 from sglang.srt.utils.offloader import (
     create_offloader_from_server_args,

@@ -195,6 +196,7 @@ def add_chunked_prefix_cache_attention_backend(backend_name):
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_xpu_xmx_available = xpu_has_xmx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

@@ -505,6 +507,16 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"
 
+        if (
+            server_args.attention_backend == "intel_xpu"
+            and server_args.device == "xpu"
+            and not _is_xpu_xmx_available
+        ):
+            logger.info(
+                "The current platform does not support Intel XMX, will fall back to the triton backend."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.prefill_attention_backend is not None and (
             server_args.prefill_attention_backend
             == server_args.decode_attention_backend


@@ -114,6 +114,7 @@ ATTENTION_BACKEND_CHOICES = [
     # Other platforms
     "intel_amx",
     "ascend",
+    "intel_xpu",
 ]
 
 LORA_BACKEND_CHOICES = ["triton", "csgmv"]

@@ -1098,6 +1099,12 @@ class ServerArgs:
             self.enable_mixed_chunk = False
             self.disable_radix_cache = True
 
+        if self.attention_backend == "intel_xpu":
+            if self.page_size not in [32, 64, 128]:
+                logger.warning(
+                    f"Intel XPU attention backend only supports a page_size of 32, 64, or 128; changing page_size from {self.page_size} to 128."
+                )
+                self.page_size = 128
 
         if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4":
             raise ValueError(
                 "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."


@@ -163,6 +163,20 @@ def _check(cc_major):
     ) >= (12, 3)
 
 
+@contextmanager
+def device_context(device: torch.device):
+    if device.type == "cpu" and is_cpu():
+        with torch.device("cpu"):
+            yield
+    else:
+        module = torch.get_device_module(device)
+        if module is not None:
+            with module.device(device.index):
+                yield
+        else:
+            raise ValueError(f"Unknown device module: {device}")
+
+
 is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)
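device_context generalizes the old torch.get_device_module(x.device).device(x.device.index) idiom (see the layernorm hunk above) so kernel launches target the right device on CUDA and XPU alike. A usage sketch, assuming the device_context above is in scope and an XPU (or CUDA) device is present:

import torch

x = torch.empty(4, device="xpu")  # or device="cuda"; illustrative
with device_context(x.device):
    y = x + 1  # kernels launched in this block run on x's device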
@@ -263,6 +277,14 @@ def use_intel_amx_backend(layer):
     return getattr(layer, "use_intel_amx_backend", False)
 
 
+def xpu_has_xmx_support():
+    # TODO: replace with a proper XPU capability query
+    if is_xpu():
+        # currently only PVC/LNL/BMG support FP64, so only these are supported for now
+        return torch.xpu.get_device_properties().has_fp64
+    return False
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
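Per the comment above, the FP64 probe is a stopgap: PVC/LNL/BMG are currently the only supported parts and they all report FP64, so has_fp64 doubles as a proxy for XMX support until a real capability query exists. Checking interactively (assumes an XPU-enabled torch build):

import torch

if torch.xpu.is_available():
    props = torch.xpu.get_device_properties()
    print(props.name, props.has_fp64)  # the intel_xpu backend is used only when True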