diff --git a/Makefile b/Makefile index 96d7df32b..a8a1fbd88 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,9 @@ FILES_TO_UPDATE = docker/Dockerfile.rocm \ docs/get_started/install.md \ docs/platforms/amd_gpu.md \ docs/platforms/ascend_npu.md \ - benchmark/deepseek_v3/README.md + docs/platforms/cpu_server.md \ + docs/platforms/xpu.md \ + benchmark/deepseek_v3/README.md update: ## Update version numbers across project files. Usage: make update @if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \ diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index bd32551f5..a50de5bf4 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -48,7 +48,7 @@ RUN --mount=type=secret,id=github_token \ cd /home/sdp && \ . /home/sdp/miniforge3/bin/activate && \ conda activate py${PYTHON_VERSION} && \ - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu + pip3 install torch==2.8.0+xpu torchao torchvision torchaudio pytorch-triton-xpu==3.4.0 --index-url https://download.pytorch.org/whl/xpu RUN --mount=type=secret,id=github_token \ cd /home/sdp && \ @@ -59,13 +59,8 @@ RUN --mount=type=secret,id=github_token \ cd sglang && cd python && \ cp pyproject_xpu.toml pyproject.toml && \ pip install . && \ - echo "Cloning ${SG_LANG_KERNEL_REPO} from ${SG_LANG_KERNEL_BRANCH}" && \ - git clone --branch ${SG_LANG_KERNEL_BRANCH} --single-branch ${SG_LANG_KERNEL_REPO} && \ - cd sgl-kernel-xpu && \ - pip install -v . && \ + pip install xgrammar --no-deps && \ pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops --root-user-action=ignore && \ - pip uninstall pytorch-triton-xpu -y && \ - pip install --pre pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu && \ conda install libsqlite=3.48.0 -y && \ # Add environment setup commands to .bashrc again (in case it was overwritten) echo ". /home/sdp/miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /home/sdp" >> /home/sdp/.bashrc diff --git a/docs/advanced_features/attention_backend.md b/docs/advanced_features/attention_backend.md index 0024aece6..ec223add1 100644 --- a/docs/advanced_features/attention_backend.md +++ b/docs/advanced_features/attention_backend.md @@ -26,6 +26,7 @@ The support matrix is split into two parts: MHA (standard attention) and MLA (mu | **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | | **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ### MLA Backends @@ -190,6 +191,13 @@ python3 -m sglang.launch_server \ --attention-backend ascend ``` +- Intel XPU +```bash +python3 -m sglang.launch_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --attention-backend intel_xpu +``` + - Wave ```bash python3 -m sglang.launch_server \ diff --git a/docs/index.rst b/docs/index.rst index f34f13e37..293f75984 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -75,6 +75,7 @@ Its core features include: platforms/tpu.md platforms/nvidia_jetson.md platforms/ascend_npu.md + platforms/xpu.md .. toctree:: :maxdepth: 1 diff --git a/docs/platforms/xpu.md b/docs/platforms/xpu.md new file mode 100644 index 000000000..4b5ecc9e1 --- /dev/null +++ b/docs/platforms/xpu.md @@ -0,0 +1,92 @@ +# XPU + +The document addresses how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on Intel GPU, [see more context about Intel GPU support within PyTorch ecosystem](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html). 
+
+Specifically, SGLang is optimized for [Intel® Arc™ Pro B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/242616/intel-arc-pro-b-series-graphics.html) and [Intel® Arc™ B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/240391/intel-arc-b-series-graphics.html).
+
+## Optimized Model List
+
+The following LLMs have been optimized on Intel GPU, and more are on the way:
+
+| Model Name | BF16 |
+|:---:|:---:|
+| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
+| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) |
+| Qwen2.5-1.5B | [Qwen/Qwen2.5-1.5B](https://huggingface.co/Qwen/Qwen2.5-1.5B) |
+
+**Note:** The model identifiers listed in the table above
+have been verified on [Intel® Arc™ B580 Graphics](https://www.intel.com/content/www/us/en/products/sku/241598/intel-arc-b580-graphics/specifications.html).
+
+## Installation
+
+### Install From Source
+
+Currently, SGLang on XPU only supports installation from source. Please refer to ["Getting Started on Intel GPU"](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) to install the XPU dependencies.
+
+```bash
+# Create and activate a conda environment
+conda create -n sgl-xpu python=3.12 -y
+conda activate sgl-xpu
+
+# Use the PyTorch XPU wheel index as the primary pip channel to avoid installing the larger CUDA-enabled builds and prevent potential runtime issues.
+pip3 install torch==2.8.0+xpu torchao torchvision torchaudio pytorch-triton-xpu==3.4.0 --index-url https://download.pytorch.org/whl/xpu
+pip3 install xgrammar --no-deps # xgrammar would pull in a CUDA-enabled Triton, which may conflict with XPU
+
+# Clone the SGLang code
+git clone https://github.com/sgl-project/sglang.git
+cd sglang
+git checkout
+
+# Use the dedicated toml file
+cd python
+cp pyproject_xpu.toml pyproject.toml
+# Install the SGLang dependencies and build the SGLang main package
+pip install --upgrade pip setuptools
+pip install -v .
+```
+
+### Install Using Docker
+
+The Docker image for XPU is under active development. Please stay tuned.
+
+## Launching the Serving Engine
+
+Example command to launch the SGLang server:
+
+```bash
+python -m sglang.launch_server \
+    --model \
+    --trust-remote-code \
+    --disable-overlap-schedule \
+    --device xpu \
+    --host 0.0.0.0 \
+    --tp 2 \ # use multiple GPUs
+    --attention-backend intel_xpu \ # use the Intel-optimized XPU attention backend
+    --page-size \ # the intel_xpu attention backend supports [32, 64, 128]
+```
+
+## Benchmarking with Requests
+
+You can benchmark performance with the `bench_serving` script.
+Run the command below in another terminal.
+
+```bash
+python -m sglang.bench_serving \
+    --dataset-name random \
+    --random-input-len 1024 \
+    --random-output-len 1024 \
+    --num-prompts 1 \
+    --request-rate inf \
+    --random-range-ratio 1.0
+```
+
+Detailed explanations of the parameters can be found with the following command:
+
+```bash
+python -m sglang.bench_serving -h
+```
+
+Additionally, requests can be formed with the
+[OpenAI Completions API](https://docs.sglang.ai/basic_usage/openai_api_completions.html)
+and sent via the command line (e.g. using `curl`) or from your own script, as in the sketch below.
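+
+For example, the following is a minimal sketch of such a script using the `openai` Python client. It assumes the server launched above is reachable at the default `http://localhost:30000` and that `openai` is installed (`pip install openai`); adjust the base URL and model path to match your deployment.
+
+```python
+# Minimal sketch: send a completion request to the SGLang OpenAI-compatible endpoint.
+# Assumes the server runs on the default port 30000 of `sglang.launch_server` and
+# that the model path below matches the one passed to --model.
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
+
+response = client.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",  # replace with the served model path
+    prompt="The capital of France is",
+    max_tokens=32,
+    temperature=0,
+)
+print(response.choices[0].text)
+```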
diff --git a/python/pyproject_xpu.toml b/python/pyproject_xpu.toml index 26557be88..ce4dc79e6 100644 --- a/python/pyproject_xpu.toml +++ b/python/pyproject_xpu.toml @@ -1,5 +1,3 @@ -# xpu is not enabled in public vllm and torch whl, -# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm [build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" @@ -17,6 +15,10 @@ classifiers = [ ] dependencies = [ + "torch==2.8.0", + "torchaudio==2.8.0", + "torchvision", + "sgl-kernel @ git+https://github.com/sgl-project/sgl-kernel-xpu.git", "IPython", "aiohttp", "anthropic>=0.20.0", @@ -61,7 +63,7 @@ dependencies = [ "transformers==4.57.1", "uvicorn", "uvloop", - "xgrammar==0.1.25", + # "xgrammar==0.1.24", , xgrammar depends on CUDA PyTorch and Triton only "grpcio==1.75.1", # keep it align with compile_proto.py "grpcio-tools==1.75.1", # keep it align with compile_proto.py "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 5604495e3..86189fa9e 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -272,7 +272,7 @@ def prepare_synthetic_inputs_for_latency_test( def extend(reqs, model_runner): # Create dummy tree_cache for benchmarks (no prefix caching, just allocation) dummy_tree_cache = SimpleNamespace( - page_size=1, + page_size=model_runner.server_args.page_size, device=model_runner.device, token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, ) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 15ab65589..e0d533479 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -50,11 +50,13 @@ from sglang.srt.utils import ( is_hip, is_npu, is_shm_available, + is_xpu, supports_custom_op, ) _is_npu = is_npu() _is_cpu = is_cpu() +_is_xpu = is_xpu() _supports_custom_op = supports_custom_op() @@ -694,7 +696,7 @@ class GroupCoordinator: ) def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): - if _is_npu or not _supports_custom_op: + if _is_npu or _is_xpu or not _supports_custom_op: self._all_gather_into_tensor(output, input) else: torch.ops.sglang.reg_all_gather_into_tensor( @@ -1298,7 +1300,7 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not _is_npu, + use_pynccl=not (_is_npu or _is_xpu), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_torch_symm_mem=use_symm_mem_allreduce, diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 5cffd3b97..89922b062 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -217,3 +217,10 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac ) return full_attn_backend + + +@register_attention_backend("intel_xpu") +def create_intel_xpu_backend(runner): + from sglang.srt.layers.attention.xpu_backend import XPUAttentionBackend + + return XPUAttentionBackend(runner) diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index b7dd39b12..edae2c52a 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ 
b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -12,6 +12,8 @@ import triton import triton.language as tl from einops import rearrange +from sglang.srt.utils import device_context + def rms_norm_ref( x, @@ -157,7 +159,7 @@ def _layer_norm_fwd( # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) - with torch.get_device_module(x.device).device(x.device.index): + with device_context(x.device): _layer_norm_fwd_1pass_kernel[grid]( x, out, diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py new file mode 100644 index 000000000..5ab4a160c --- /dev/null +++ b/python/sglang/srt/layers/attention/xpu_backend.py @@ -0,0 +1,1028 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionMetadata, + make_local_attention_virtual_batches, + merge_state_v2_wrapper, + prepare_swa_spec_page_table_triton, +) +from sglang.srt.managers.schedule_batch import get_global_server_args +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + +from sgl_kernel import merge_state_v2 +from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + + +class XPUAttentionBackend(AttentionBackend): + """XPU FlashAttention backend, currently based on FlashAttentionBackend, will be refactored later. + + TODO: + - Prefill and Decode disaggregation, currently only chunked prefill is supported + - Speculative Decoding support + - XPU Graph support, see https://github.com/pytorch/pytorch/issues/162143 + - MLA support + """ + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + speculative_step_id=0, + topk=0, + speculative_num_steps=0, + ): + super().__init__() + + assert not ( + model_runner.sliding_window_size is not None + and model_runner.model_config.is_encoder_decoder + ), "Sliding window and cross attention are not supported together" + + self.forward_metadata: FlashAttentionMetadata = None + # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify + self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None + self.max_context_len = model_runner.model_config.context_len + self.device = model_runner.device + self.decode_cuda_graph_metadata = {} + self.target_verify_metadata = {} + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.kv_cache_dtype = model_runner.kv_cache_dtype + self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype + self.page_size = model_runner.page_size + self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA + assert ( + self.use_mla is False + ), "XPUAttentionBackend doesn't support MLA yet, please use --attention-backend triton instead." 
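+        # The fields below mirror FlashAttentionBackend's constructor; the speculative-decoding
+        # settings are kept for interface compatibility even though speculative decoding is not
+        # yet supported by this backend (see the TODO list in the class docstring).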
+ self.skip_prefill = skip_prefill + self.is_hybrid = model_runner.is_hybrid + if self.is_hybrid: + self.full_to_swa_index_mapping = ( + model_runner.token_to_kv_pool.full_to_swa_index_mapping + ) + self.topk = model_runner.server_args.speculative_eagle_topk or 0 + self.speculative_num_steps = speculative_num_steps + self.speculative_num_draft_tokens = ( + model_runner.server_args.speculative_num_draft_tokens + ) + self.speculative_step_id = speculative_step_id + + # Local attention settings + self.attention_chunk_size = ( + model_runner.attention_chunk_size + if hasattr(model_runner, "attention_chunk_size") + else None + ) + + # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata. + # We use `layer.sliding_window_size` to decide whether to use SWA for each layer. + self.sliding_window_size = model_runner.sliding_window_size + self.has_swa = ( + self.sliding_window_size is not None and self.sliding_window_size > -1 + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Initialize forward metadata hence all layers in the forward pass can reuse it.""" + metadata = FlashAttentionMetadata() + seqlens_in_batch = forward_batch.seq_lens + batch_size = forward_batch.batch_size + device = seqlens_in_batch.device + + if forward_batch.forward_mode.is_decode_or_idle(): + # Draft Decode + if forward_batch.spec_info is not None: + assert ( + False + ), "XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead." + if self.topk <= 1: + metadata.cache_seqlens_int32 = ( + seqlens_in_batch + (self.speculative_step_id + 1) + ).to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.speculative_step_id + 1 + ) + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + else: + metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32) + metadata.max_seq_len_q = self.topk + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_q = torch.arange( + 0, + batch_size * self.topk + 1, + step=self.topk, + dtype=torch.int32, + device=device, + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + + metadata_expand = FlashAttentionMetadata() + decode_length = self.speculative_step_id + 1 + metadata_expand.cache_seqlens_int32 = torch.full( + (seqlens_in_batch.numel() * self.topk,), + decode_length, + device=device, + dtype=torch.int32, + ) + metadata_expand.max_seq_len_q = 1 + metadata_expand.cu_seqlens_q = torch.arange( + 0, + metadata_expand.cache_seqlens_int32.numel() + 1, + dtype=torch.int32, + device=device, + ) + metadata_expand.cu_seqlens_k = torch.arange( + 0, + metadata_expand.cache_seqlens_int32.numel() * decode_length + 1, + step=decode_length, + dtype=torch.int32, + device=device, + ) + # shape: [bs, num_steps, topk] -> [bs x topk, num_steps] + cache_loc = forward_batch.out_cache_loc.view( + -1, self.speculative_num_steps + ) + metadata_expand.page_table = ( + cache_loc[:, 
:decode_length].contiguous().to(torch.int32) + ) + self.forward_metadata_spec_decode_expand = metadata_expand + else: + # Normal Decode + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + # TODO: we need to test this part for llama 4 eagle case + self._init_local_attn_metadata(forward_batch, metadata, device) + elif forward_batch.forward_mode.is_target_verify(): + if self.topk <= 1: + metadata.cache_seqlens_int32 = ( + forward_batch.seq_lens + self.speculative_num_draft_tokens + ).to(torch.int32) + metadata.max_seq_len_q = self.speculative_num_draft_tokens + metadata.max_seq_len_k = ( + forward_batch.seq_lens_cpu.max().item() + + self.speculative_num_draft_tokens + ) + metadata.cu_seqlens_q = torch.arange( + 0, + batch_size * self.speculative_num_draft_tokens + 1, + self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + + self._init_local_attn_metadata(forward_batch, metadata, device) + else: + metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32) + metadata.max_seq_len_q = self.speculative_num_draft_tokens + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_q = torch.arange( + 0, + batch_size * self.speculative_num_draft_tokens + 1, + step=self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + + metadata_expand = FlashAttentionMetadata() + + metadata_expand.max_seq_len_q = 1 + metadata_expand.cu_seqlens_q = torch.arange( + 0, + forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens + + 1, + dtype=torch.int32, + device=device, + ) + + # create expand page table + offsets = torch.arange( + self.speculative_num_draft_tokens, device=device + ).unsqueeze( + 0 + ) # shape: (1, self.speculative_num_draft_tokens) + cols = offsets.expand( + forward_batch.seq_lens.numel(), -1 + ) + forward_batch.seq_lens.unsqueeze(1) + cum_len = torch.nn.functional.pad( + torch.cumsum( + ( + forward_batch.seq_lens + self.speculative_num_draft_tokens + ).repeat_interleave(self.speculative_num_draft_tokens), + dim=0, + ), + (1, 0), + )[:-1] + mask_extraction_indices = ( + cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0) + + cum_len[:, None] + ).view(1, -1) + mask = forward_batch.spec_info.custom_mask[ + mask_extraction_indices + ].view( + -1, self.speculative_num_draft_tokens + ) # (bsz * draft_num, draft_num) + + # shift table indices to avoid padding + # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0], + # [8, 9, 10], [1, 1, 0], + # [8, 9, 10]] [1, 0, 1]] + # if masked with padding [[8, 0, 
0], our mask without padding [[8, 9, 10], + # [8, 9, 0], [8, 9, 10], + # [8, 0, 10]] [8, 10, 9]] + # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row + col_indices = offsets.expand( + mask.shape[0], self.speculative_num_draft_tokens + ) + # Build keys: if an entry is valid (mask==True), keep its original index; + # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries. + keys = torch.where( + mask, col_indices, col_indices + self.speculative_num_draft_tokens + ) + _, sort_order = torch.sort(keys, dim=1) + non_masked_page_table = ( + forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + .gather(1, cols) + .repeat_interleave(self.speculative_num_draft_tokens, dim=0) + ) # (bsz, draft_num) + metadata_expand.page_table = non_masked_page_table.gather(1, sort_order) + metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32) + metadata_expand.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + self.forward_metadata_spec_decode_expand = metadata_expand + + if self.has_swa: + self._init_sliding_window_attn_spec_metadata( + metadata, metadata_expand + ) + + elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(): + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + + if ( + any(forward_batch.extend_prefix_lens_cpu) + or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND + ): + extend_seq_lens = forward_batch.extend_seq_lens + metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) + metadata.cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) + ) + else: + metadata.max_seq_len_q = metadata.max_seq_len_k + metadata.cu_seqlens_q = metadata.cu_seqlens_k + + # Setup local attention if enabled + if forward_batch.forward_mode == ForwardMode.EXTEND: + self._init_local_attn_metadata(forward_batch, metadata, device) + + # Encoder metadata for cross attention + if forward_batch.encoder_lens is not None: + assert ( + forward_batch.encoder_lens.numel() == 1 + ), "Only encoder size 1 is supported for now" + + metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32) + metadata.encoder_cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32), + (1, 0), + ) + metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item() + metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k + ] + + # Currently only support forward_batch.encoder_lens.numel() == 1 + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, + metadata.encoder_max_seq_len_k : ( + metadata.encoder_max_seq_len_k + metadata.max_seq_len_k + ), + ] + + # Convert the page table to a strided format which is needed by FA3 API + if self.page_size > 1: + self.strided_indices = torch.arange( + 0, metadata.page_table.shape[1], self.page_size, device=self.device + ) + metadata.page_table = ( + metadata.page_table[:, self.strided_indices] 
// self.page_size + ) + + self.forward_metadata = metadata + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + ): + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + + # Use precomputed metadata across all layers + metadata = self.forward_metadata + + # Calculate window size (can be moved to metadata if layer properties don't change) + # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 + # here is two side inclusive + is_swa = ( + layer.sliding_window_size is not None and layer.sliding_window_size > -1 + ) + window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1) + + # currently no FP8 KV cache supported + k_descale, v_descale = None, None + # # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention + # # has corresponding quantization method so that layer.k_scale is not None, + # # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case. + # if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256: + # if layer.k_scale is not None: + # descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + # k_descale = layer.k_scale.expand(descale_shape) + # v_descale = layer.v_scale.expand(descale_shape) + # q = q.to(self.kv_cache_dtype) + # q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None + # k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None + causal = not layer.is_cross_attention + + # Check if we should use local attention + use_local_attn = ( + self.attention_chunk_size is not None + and metadata.local_attn_metadata is not None + and (hasattr(layer, "use_irope") and layer.use_irope) + ) + + # We do cascade attention for Target Verify with topk > 1 + # We don't use cascade attention for Sliding Window Attention: + # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes. + # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it. 
+ use_cascade_attn = ( + forward_batch.forward_mode.is_target_verify() + and self.topk > 1 + and not is_swa + ) + + # For fa3 interface version compatibility, we put new fields into conditional keyword args + kwargs = {} + if sinks is not None: + kwargs["sinks"] = sinks + + # Get the appropriate page table based on whether we're using local attention + if use_local_attn: + local_metadata = metadata.local_attn_metadata + page_table = local_metadata.local_block_table + cu_seqlens_q = local_metadata.local_query_start_loc + cache_seqlens = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + elif is_swa and metadata.swa_spec_metadata is not None: + swa_spec_metadata = metadata.swa_spec_metadata + page_table = swa_spec_metadata.page_table + cu_seqlens_q = swa_spec_metadata.cu_seqlens_q + cache_seqlens = swa_spec_metadata.cache_seqlens_int32 + max_seqlen_q = swa_spec_metadata.max_seq_len_q + cu_seqlens_k = swa_spec_metadata.cu_seqlens_k + else: + page_table = metadata.page_table + cu_seqlens_q = metadata.cu_seqlens_q + cache_seqlens = metadata.cache_seqlens_int32 + max_seqlen_q = metadata.max_seq_len_q + cu_seqlens_k = metadata.cu_seqlens_k + + # Use Flash Attention for prefill + if not self.use_mla: + # Do multi-head attention + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + if layer.is_cross_attention: + page_table = metadata.encoder_page_table + cache_seqlens = metadata.encoder_lens_int32 + cu_seqlens_k = metadata.encoder_cu_seqlens_k + window_size = (-1, -1) + + result = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + **kwargs, + ) + + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + **kwargs, + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + else: + if ( + forward_batch.attn_attend_prefix_cache is not None + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + # Do multi-head attention with chunked prefix cache + if forward_batch.attn_attend_prefix_cache: + assert not 
get_global_server_args().disable_chunked_prefix_cache + # MHA for chunked prefix kv cache when running model with MLA + assert forward_batch.prefix_chunk_idx is not None + assert forward_batch.prefix_chunk_cu_seq_lens is not None + assert forward_batch.prefix_chunk_max_seq_lens is not None + + chunk_idx = forward_batch.prefix_chunk_idx + assert chunk_idx >= 0 + + assert forward_batch.mha_return_lse + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx], + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx], + softmax_scale=layer.scaling, + causal=False, + return_softmax_lse=True, + ) + else: + # MHA for extend part of sequence without attending prefix kv cache + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=metadata.cu_seqlens_q, + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=metadata.max_seq_len_q, + softmax_scale=layer.scaling, + causal=True, + return_softmax_lse=forward_batch.mha_return_lse, + ) + if forward_batch.mha_return_lse: + output, lse, *rest = output + lse = torch.transpose(lse, 0, 1).contiguous() + return output, lse + return output + else: + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ).to(q.dtype) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + + result = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + 
k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + + # Use precomputed metadata across all layers + metadata = self.forward_metadata + local_attn_metadata = getattr(metadata, "local_attn_metadata", None) + use_local_attn = ( + self.attention_chunk_size is not None + and local_attn_metadata is not None + and (hasattr(layer, "use_irope") and layer.use_irope) + ) + + # When Spec Decode enabled, forward_decode would be called with two mode: + # 1. DRAFT_DECODE: we enable cascade attention when top_k > 1 + # 2. IDLE: we don’t need cascade attention, spec_info will be none in this case + use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1 + + # Calculate window size (can be moved to metadata if layer properties don't change) + # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 + # here is two side inclusive + window_size = ( + (layer.sliding_window_size, 0) + if layer.sliding_window_size is not None and layer.sliding_window_size > -1 + else (-1, -1) + ) + causal = not layer.is_cross_attention + + # For fa3 interface version compatibility, we put new fields into conditional keyword args + kwargs = {} + if sinks is not None: + kwargs["sinks"] = sinks + + k_descale, v_descale = None, None + # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention + # has corresponding quantization method so that layer.k_scale is not None, + # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case. 
+ if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256: + if layer.k_scale is not None: + descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + k_descale = layer.k_scale.expand(descale_shape) + v_descale = layer.v_scale.expand(descale_shape) + q = q.to(self.kv_cache_dtype) + q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None + k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None + if not self.use_mla: + # Do multi-head attention + + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + if layer.is_cross_attention: + # Always use non-chunked logic for cross-attention + o = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=metadata.encoder_page_table, + cache_seqlens=metadata.encoder_lens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.encoder_cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=layer.scaling, + causal=False, + window_size=(-1, -1), + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + **kwargs, + ) + elif use_local_attn: + # Use chunked (local) attention batching for self-attention + o = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=local_attn_metadata.local_block_table, + cache_seqlens=local_attn_metadata.local_seqused_k, + cu_seqlens_q=local_attn_metadata.local_query_start_loc, + cu_seqlens_k_new=None, + max_seqlen_q=local_attn_metadata.local_max_query_len, + softmax_scale=layer.scaling, + causal=True, + window_size=(-1, -1), + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + **kwargs, + ) + else: + page_table = metadata.page_table + cache_seqlens = metadata.cache_seqlens_int32 + cu_seqlens_k = metadata.cu_seqlens_k + max_seqlen_q = metadata.max_seq_len_q + q_reshaped = q.contiguous().view( + -1, layer.tp_q_head_num, layer.head_dim + ) + + # Default: single-token self-attention + result = flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + **kwargs, + ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + **kwargs, + ) + ) + o, _ = merge_state_v2( + o, + 
softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + else: + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to( + q.dtype + ) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + max_seqlen_q = metadata.max_seq_len_q + + result = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=metadata.page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states + ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + o, _ = merge_state_v2( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for sequence length in CUDA graph.""" + return 1 + + def _init_local_attn_metadata( + self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device + ): + """Centralized utility to initialize local_attn_metadata if chunked attention is enabled.""" + if self.attention_chunk_size is None: + metadata.local_attn_metadata = None + return + + cu_seqlens_q = metadata.cu_seqlens_q + cache_seqlens_int32 = metadata.cache_seqlens_int32 + if self.is_hybrid: + page_table = self.full_to_swa_index_mapping[metadata.page_table].to( + torch.int32 + ) + else: + page_table = metadata.page_table + if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None: + metadata.local_attn_metadata = None + return + + cu_seqlens_q_np = cu_seqlens_q.cpu().numpy() + seq_lens_np = cache_seqlens_int32.cpu().numpy() + ( + seqlens_q_local_np, + cu_seqlens_q_local_np, + seqlens_k_local_np, + block_table_local, + ) = make_local_attention_virtual_batches( + self.attention_chunk_size, + cu_seqlens_q_np, + seq_lens_np, + page_table, + self.page_size, + ) + + local_metadata = 
FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device), + local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device), + local_block_table=block_table_local.to(device), + local_max_query_len=int(seqlens_q_local_np.max()), + local_max_seq_len=int(seqlens_k_local_np.max()), + ) + metadata.local_attn_metadata = local_metadata + + def _init_sliding_window_attn_spec_metadata( + self, + metadata: FlashAttentionMetadata, + metadata_expand: FlashAttentionMetadata, + metadata_swa: Optional[FlashAttentionMetadata] = None, + ): + # TODO: support page_size > 1 for swa spec + assert ( + self.page_size == 1 + ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention" + + cache_seqlens_int32 = ( + metadata.cache_seqlens_int32.repeat_interleave( + self.speculative_num_draft_tokens + ) + + metadata_expand.cache_seqlens_int32 + ) + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0) + ) + bs = cache_seqlens_int32.shape[0] + page_table = ( + metadata.page_table.new_zeros( + (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1]) + ) + if metadata_swa is None + else metadata_swa.page_table + ) + + prepare_swa_spec_page_table_triton( + page_table, + metadata.page_table, + metadata_expand.page_table, + metadata.cache_seqlens_int32, + metadata_expand.cache_seqlens_int32, + self.speculative_num_draft_tokens, + ) + + if metadata_swa is None: + metadata_swa = FlashAttentionMetadata() + metadata_swa.max_seq_len_q = 1 + metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q + metadata_swa.cache_seqlens_int32 = cache_seqlens_int32 + metadata_swa.cu_seqlens_k = cu_seqlens_k + metadata_swa.page_table = page_table + else: + metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32) + metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k) + + metadata.swa_spec_metadata = metadata_swa diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index c60314ad9..1a7263cd5 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -42,7 +42,7 @@ _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() _is_xpu = is_xpu() -if _is_cuda: +if _is_cuda or _is_xpu: # if _is_flashinfer_available: # from flashinfer.norm import fused_add_rmsnorm # else: @@ -52,13 +52,6 @@ if _is_cuda: gemma_rmsnorm, rmsnorm, ) -elif _is_xpu: - from sgl_kernel import ( - fused_add_rmsnorm, - gemma_fused_add_rmsnorm, - gemma_rmsnorm, - rmsnorm, - ) if _use_aiter: from aiter import rmsnorm2d_fwd as rms_norm from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index d796008c8..5c195516c 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -39,10 +39,11 @@ if TYPE_CHECKING: CombineInput, ) -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import is_cuda, is_hip, is_xpu _is_cuda = is_cuda() _is_hip = is_hip() +_is_xpu = is_xpu() if _is_cuda: from sgl_kernel import ( awq_dequantize, @@ -58,8 +59,12 @@ elif _is_hip: ) warnings.warn(f"HIP does not support fused_marlin_moe currently.") +elif _is_xpu: + from sgl_kernel import awq_dequantize + + warnings.warn(f"XPU does not support fused_marlin_moe currently.") else: - warnings.warn(f"Only CUDA and HIP support AWQ currently.") + warnings.warn(f"Only CUDA, HIP 
and XPU support AWQ currently.") logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 83842b0cc..2c8181ebe 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -115,7 +115,7 @@ class RotaryEmbedding(CustomOp): if dtype == torch.float32 or ( (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) and not (_is_cpu and _is_cpu_amx_available) - and not _is_xpu + and not (_is_xpu) ): from vllm._custom_ops import rotary_embedding @@ -302,6 +302,7 @@ class RotaryEmbedding(CustomOp): offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: make a wrapper, and XPU will implement this kernel later. + self.cos_sin_cache = self.cos_sin_cache.to(query.device) return self.forward_native(positions, query, key, offsets) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0c6407130..18b7428d6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -142,6 +142,7 @@ from sglang.srt.utils import ( monkey_patch_vllm_gguf_config, set_cuda_arch, slow_rank_detector, + xpu_has_xmx_support, ) from sglang.srt.utils.offloader import ( create_offloader_from_server_args, @@ -195,6 +196,7 @@ def add_chunked_prefix_cache_attention_backend(backend_name): _is_hip = is_hip() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() +_is_xpu_xmx_available = xpu_has_xmx_support() # Use a small KV cache pool size for tests in CI SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) @@ -505,6 +507,16 @@ class ModelRunner: ) server_args.attention_backend = "torch_native" + if ( + server_args.attention_backend == "intel_xpu" + and server_args.device == "xpu" + and not _is_xpu_xmx_available + ): + logger.info( + "The current platform does not support Intel XMX, will fallback to triton backend." + ) + server_args.attention_backend = "triton" + if server_args.prefill_attention_backend is not None and ( server_args.prefill_attention_backend == server_args.decode_attention_backend diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 25cf28f31..2e432de1b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -114,6 +114,7 @@ ATTENTION_BACKEND_CHOICES = [ # Other platforms "intel_amx", "ascend", + "intel_xpu", ] LORA_BACKEND_CHOICES = ["triton", "csgmv"] @@ -1098,6 +1099,12 @@ class ServerArgs: self.enable_mixed_chunk = False self.disable_radix_cache = True + if self.attention_backend == "intel_xpu": + if self.page_size not in [32, 64, 128]: + logger.warning( + f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128." + ) + self.page_size = 128 if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4": raise ValueError( "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead." 
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 2264168e2..148a73bf8 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -163,6 +163,20 @@ def _check(cc_major): ) >= (12, 3) +@contextmanager +def device_context(device: torch.device): + if device.type == "cpu" and is_cpu(): + with torch.device("cpu"): + yield + else: + module = torch.get_device_module(device) + if module is not None: + with module.device(device.index): + yield + else: + raise ValueError(f"Unknown device module: {device}") + + is_ampere_with_cuda_12_3 = lambda: _check(8) is_hopper_with_cuda_12_3 = lambda: _check(9) @@ -263,6 +277,14 @@ def use_intel_amx_backend(layer): return getattr(layer, "use_intel_amx_backend", False) +def xpu_has_xmx_support(): + # TODO: update with XPU capalibity query + if is_xpu(): + # currently only PVC/LNL/BMG supports F64, so we only support these now + return torch.xpu.get_device_properties().has_fp64 + return False + + def is_flashinfer_available(): """ Check whether flashinfer is available. diff --git a/test/srt/xpu/test_intel_xpu_backend.py b/test/srt/xpu/test_intel_xpu_backend.py index 91ebd57a2..a0f301617 100644 --- a/test/srt/xpu/test_intel_xpu_backend.py +++ b/test/srt/xpu/test_intel_xpu_backend.py @@ -8,6 +8,7 @@ import unittest from functools import wraps from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, CustomTestCase, is_in_ci, @@ -55,6 +56,10 @@ class TestIntelXPUBackend(CustomTestCase): def test_latency_qwen_model(self): return DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN + @intel_xpu_benchmark(["--attention-backend", "intel_xpu", "--page-size", "128"]) + def test_attention_backend(self): + return DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE + if __name__ == "__main__": unittest.main()