diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 3bdf7c7c2..f7ca5e203 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+    # For sliding window attention with topk > 1 spec decoding
+    swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+
     # Copied from:
     # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
             else None
         )
 
+        # The sliding_window_size can be different for each layer; this value is only used for preparing SWA metadata.
+        # We use `layer.sliding_window_size` to decide whether each layer actually uses SWA.
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.has_swa = (
+            self.sliding_window_size is not None and self.sliding_window_size > -1
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
                     (1, 0),
                 )
                 self.forward_metadata_spec_decode_expand = metadata_expand
+
+            if self.has_swa:
+                self._init_sliding_window_attn_spec_metadata(
+                    metadata, metadata_expand
+                )
+
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
-        window_size = (
-            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
-            else (-1, -1)
+        is_swa = (
+            layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
+        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
             )
 
             # We do cascade attention for Target Verify with topk > 1
+            # We don't use cascade attention for Sliding Window Attention:
+            # - The first stage of cascade attention would need a different window size for each q, but the FA3 interface doesn't support passing in a list of window sizes.
+            # - The overhead of duplicated computation of the common prefix is small for sliding window layers (seq_len <= window_size), so we can simply expand it.
             use_cascade_attn = (
-                forward_batch.forward_mode.is_target_verify() and self.topk > 1
+                forward_batch.forward_mode.is_target_verify()
+                and self.topk > 1
+                and not is_swa
             )
 
             # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
             cu_seqlens_q = local_metadata.local_query_start_loc
             cache_seqlens = local_metadata.local_seqused_k
             max_seqlen_q = local_metadata.local_max_query_len
-            max_seqlen_k = local_metadata.local_max_seq_len
+        elif is_swa and metadata.swa_spec_metadata is not None:
+            swa_spec_metadata = metadata.swa_spec_metadata
+            page_table = swa_spec_metadata.page_table
+            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+            max_seqlen_q = swa_spec_metadata.max_seq_len_q
+            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
         else:
             page_table = metadata.page_table
             cu_seqlens_q = metadata.cu_seqlens_q
             cache_seqlens = metadata.cache_seqlens_int32
             max_seqlen_q = metadata.max_seq_len_q
-            max_seqlen_k = metadata.max_seq_len_k
             cu_seqlens_k = metadata.cu_seqlens_k
 
         # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        if self.has_swa:
+            self.target_verify_metadata_topk_swa = {
+                "cache_seqlens": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
                 max_bs,
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):
             self.target_verify_metadata_topk_normal[bs] = metadata
             self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+            if self.has_swa:
+                metadata_swa = FlashAttentionMetadata()
+                metadata_swa.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_swa["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_swa.max_seq_len_q = 1
+                metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]
+                self.target_verify_metadata_topk_swa[bs] = metadata_swa
+                metadata.swa_spec_metadata = metadata_swa
+
         elif forward_mode.is_draft_extend():
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             )
 
+            if self.has_swa:
+                metadata_swa = self.target_verify_metadata_topk_swa[bs]
+                self._init_sliding_window_attn_spec_metadata(
+                    metadata, metadata_expand, metadata_swa
+                )
+
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
             lam.local_max_query_len = int(seqlens_q_local_np.max())
             lam.local_max_seq_len = int(seqlens_k_local_np.max())
 
+    def _init_sliding_window_attn_spec_metadata(
+        self,
+        metadata: FlashAttentionMetadata,
+        metadata_expand: FlashAttentionMetadata,
+        metadata_swa: Optional[FlashAttentionMetadata] = None,
+    ):
+        # TODO: support page_size > 1 for swa spec
+        assert (
+            self.page_size == 1
+        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 for sliding window attention"
+
+        cache_seqlens_int32 = (
+            metadata.cache_seqlens_int32.repeat_interleave(
+                self.speculative_num_draft_tokens
+            )
+            + metadata_expand.cache_seqlens_int32
+        )
+        cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+        )
+        bs = cache_seqlens_int32.shape[0]
+        page_table = (
+            metadata.page_table.new_zeros(
+                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+            )
+            if metadata_swa is None
+            else metadata_swa.page_table
+        )
+
+        prepare_swa_spec_page_table_triton(
+            page_table,
+            metadata.page_table,
+            metadata_expand.page_table,
+            metadata.cache_seqlens_int32,
+            metadata_expand.cache_seqlens_int32,
+            self.speculative_num_draft_tokens,
+        )
+
+        if metadata_swa is None:
+            metadata_swa = FlashAttentionMetadata()
+            metadata_swa.max_seq_len_q = 1
+            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+            metadata_swa.cu_seqlens_k = cu_seqlens_k
+            metadata_swa.page_table = page_table
+        else:
+            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+        metadata.swa_spec_metadata = metadata_swa
+
+
+@triton.jit
+def _prepare_swa_spec_page_table_kernel(
+    dst_ptr,
+    src_a_ptr,
+    src_b_ptr,
+    seq_len_a_ptr,
+    seq_len_b_ptr,
+    dst_stride_m,
+    dst_stride_n,
+    a_stride_m,
+    a_stride_n,
+    b_stride_m,
+    b_stride_n,
+    LEN_A: tl.constexpr,
+    LEN_B: tl.constexpr,
+    REPEAT_STEP: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    idx_a = pid_m // REPEAT_STEP
+    idx_b = pid_m
+    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+    seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    total_len = seq_len_a + seq_len_b
+
+    if pid_n * BLOCK_N >= total_len:
+        return
+
+    mask = offs_n < total_len
+    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+    if (pid_n + 1) * BLOCK_N < seq_len_a:
+        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+        a_mask = mask & (offs_n < LEN_A)
+        val = tl.load(a_ptr, mask=a_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    elif pid_n * BLOCK_N >= seq_len_a:
+        offs_b = offs_n - seq_len_a
+        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+        b_mask = mask & (offs_b < LEN_B)
+        val = tl.load(b_ptr, mask=b_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    else:
+        # mixed block: entries come from both table a and table b
+        a_offs = offs_n
+        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+        a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+        b_offs = offs_n - seq_len_a
+        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+        b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+        result = tl.where(offs_n < seq_len_a, a_val, b_val)
+        tl.store(dst, result, mask=mask)
+
+
+def prepare_swa_spec_page_table_triton(
+    page_table_dst: torch.Tensor,
+    page_table_a: torch.Tensor,
+    page_table_b: torch.Tensor,  # expanded page table
+    seq_len_a: torch.Tensor,
+    seq_len_b: torch.Tensor,  # expanded seq lens
+    speculative_num_draft_tokens: int,
+):
+    # concat page_table and the expanded page_table by kv seq length
+    bs = seq_len_a.numel()
+    bs_expand = seq_len_b.numel()
+    assert bs_expand == bs * speculative_num_draft_tokens
+
+    LEN_A = page_table_a.shape[1]
+    LEN_B = page_table_b.shape[1]
+    LEN_OUT = LEN_A + LEN_B
+    REPEAT_STEP = speculative_num_draft_tokens
+    BLOCK_N = 256
+
+    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+    _prepare_swa_spec_page_table_kernel[grid](
+        page_table_dst,
+        page_table_a,
+        page_table_b,
+        seq_len_a,
+        seq_len_b,
+        page_table_dst.stride(0),
+        page_table_dst.stride(1),
+        page_table_a.stride(0),
+        page_table_a.stride(1),
+        page_table_b.stride(0),
+        page_table_b.stride(1),
+        LEN_A=LEN_A,
+        LEN_B=LEN_B,
+        REPEAT_STEP=REPEAT_STEP,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+    )
 
 
 class FlashAttentionMultiStepBackend:
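
Note on the kv-length bookkeeping in `_init_sliding_window_attn_spec_metadata`: each expanded row's kv length is its request's prefix length (repeated `speculative_num_draft_tokens` times) plus the per-draft suffix length from `metadata_expand`. A small worked example with hypothetical numbers:

    import torch

    prefix_lens = torch.tensor([5, 7], dtype=torch.int32)       # metadata.cache_seqlens_int32 (bs = 2)
    draft_lens = torch.tensor([1, 2, 1, 2], dtype=torch.int32)  # metadata_expand.cache_seqlens_int32
    num_draft = 2                                               # speculative_num_draft_tokens

    cache_seqlens = prefix_lens.repeat_interleave(num_draft) + draft_lens
    # tensor([6, 7, 8, 9], dtype=torch.int32)
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    # tensor([ 0,  6, 13, 21, 30], dtype=torch.int32)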
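
The Triton kernel itself is a row-wise concatenation: destination row `i` receives the first `seq_len_a[i // num_draft]` page entries of prefix row `i // num_draft`, followed by the first `seq_len_b[i]` entries of expanded row `i`. A minimal pure-PyTorch sketch of the same semantics (inferred from the kernel; useful as a unit-test reference, not part of the patch):

    import torch

    def swa_spec_page_table_ref(
        page_table_a: torch.Tensor,  # [bs, LEN_A] prefix page table
        page_table_b: torch.Tensor,  # [bs * num_draft, LEN_B] expanded page table
        seq_len_a: torch.Tensor,     # [bs] prefix kv lengths
        seq_len_b: torch.Tensor,     # [bs * num_draft] per-draft kv lengths
        num_draft: int,
    ) -> torch.Tensor:
        bs_expand = seq_len_b.numel()
        out = page_table_a.new_zeros(
            (bs_expand, page_table_a.shape[1] + page_table_b.shape[1])
        )
        for i in range(bs_expand):
            la = int(seq_len_a[i // num_draft])  # prefix row shared by num_draft expanded rows
            lb = int(seq_len_b[i])
            out[i, :la] = page_table_a[i // num_draft, :la]
            out[i, la : la + lb] = page_table_b[i, :lb]
        return out  # positions past la + lb stay zero, like the kernel's untouched tail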