Fix FA3 swa spec verify topk>1 (#9658)
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:

    local_attn_metadata: Optional[LocalAttentionMetadata] = None

    # For sliding window attention topk>1 spec decoding
    swa_spec_metadata: Optional[FlashAttentionMetadata] = None


# Copied from:
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
            else None
        )

        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
        self.sliding_window_size = model_runner.sliding_window_size
        self.has_swa = (
            self.sliding_window_size is not None and self.sliding_window_size > -1
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
        metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
                    (1, 0),
                )
                self.forward_metadata_spec_decode_expand = metadata_expand

                if self.has_swa:
                    self._init_sliding_window_attn_spec_metadata(
                        metadata, metadata_expand
                    )

        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        is_swa = (
            layer.sliding_window_size is not None and layer.sliding_window_size > -1
        )
        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
        )

        # We do cascade attention for Target Verify with topk > 1
        # We don't use cascade attention for Sliding Window Attention:
        # - Different window sizes should be passed in for each q in the first stage of cascade attention, but the FA3 interface doesn't support passing in a list of window sizes.
        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
        use_cascade_attn = (
            forward_batch.forward_mode.is_target_verify() and self.topk > 1
            forward_batch.forward_mode.is_target_verify()
            and self.topk > 1
            and not is_swa
        )

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
            cu_seqlens_q = local_metadata.local_query_start_loc
            cache_seqlens = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
        elif is_swa and metadata.swa_spec_metadata is not None:
            swa_spec_metadata = metadata.swa_spec_metadata
            page_table = swa_spec_metadata.page_table
            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
            max_seqlen_q = swa_spec_metadata.max_seq_len_q
            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
        else:
            page_table = metadata.page_table
            cu_seqlens_q = metadata.cu_seqlens_q
            cache_seqlens = metadata.cache_seqlens_int32
            max_seqlen_q = metadata.max_seq_len_q
            max_seqlen_k = metadata.max_seq_len_k
            cu_seqlens_k = metadata.cu_seqlens_k

        # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
                ),
            }

            if self.has_swa:
                self.target_verify_metadata_topk_swa = {
                    "cache_seqlens": torch.zeros(
                        max_bs * self.speculative_num_draft_tokens,
                        dtype=torch.int32,
                        device=self.device,
                    ),
                    "cu_seqlens_k": torch.zeros(
                        max_bs * self.speculative_num_draft_tokens + 1,
                        dtype=torch.int32,
                        device=self.device,
                    ),
                    "cu_seqlens_q": torch.arange(
                        0,
                        max_bs * self.speculative_num_draft_tokens + 1,
                        dtype=torch.int32,
                        device=self.device,
                    ),
                    "page_table": torch.zeros(
                        max_bs * self.speculative_num_draft_tokens,
                        self.max_context_len,
                        dtype=torch.int32,
                        device=self.device,
                    ),
                }

        self.encoder_metadata = {
            "encoder_page_table": torch.zeros(
                max_bs,
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):

                self.target_verify_metadata_topk_normal[bs] = metadata
                self.target_verify_metadata_topk_expand[bs] = metadata_expand

                if self.has_swa:
                    metadata_swa = FlashAttentionMetadata()
                    metadata_swa.cache_seqlens_int32 = (
                        self.target_verify_metadata_topk_swa["cache_seqlens"][
                            : bs * self.speculative_num_draft_tokens
                        ]
                    )
                    metadata_swa.max_seq_len_q = 1
                    metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
                        "cu_seqlens_q"
                    ][: bs * self.speculative_num_draft_tokens + 1]
                    metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
                        "cu_seqlens_k"
                    ][: bs * self.speculative_num_draft_tokens + 1]

                    metadata_swa.page_table = self.target_verify_metadata_topk_swa[
                        "page_table"
                    ][: bs * self.speculative_num_draft_tokens]
                    self.target_verify_metadata_topk_swa[bs] = metadata_swa
                    metadata.swa_spec_metadata = metadata_swa

        elif forward_mode.is_draft_extend():
            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                :bs
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
                    )
                )

                if self.has_swa:
                    metadata_swa = self.target_verify_metadata_topk_swa[bs]
                    self._init_sliding_window_attn_spec_metadata(
                        metadata, metadata_expand, metadata_swa
                    )

        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
        lam.local_max_query_len = int(seqlens_q_local_np.max())
        lam.local_max_seq_len = int(seqlens_k_local_np.max())

    def _init_sliding_window_attn_spec_metadata(
        self,
        metadata: FlashAttentionMetadata,
        metadata_expand: FlashAttentionMetadata,
        metadata_swa: Optional[FlashAttentionMetadata] = None,
    ):
        # TODO: support page_size > 1 for swa spec
        assert (
            self.page_size == 1
        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"

        cache_seqlens_int32 = (
            metadata.cache_seqlens_int32.repeat_interleave(
                self.speculative_num_draft_tokens
            )
            + metadata_expand.cache_seqlens_int32
        )
        cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
        )
        bs = cache_seqlens_int32.shape[0]
        page_table = (
            metadata.page_table.new_zeros(
                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
            )
            if metadata_swa is None
            else metadata_swa.page_table
        )

        prepare_swa_spec_page_table_triton(
            page_table,
            metadata.page_table,
            metadata_expand.page_table,
            metadata.cache_seqlens_int32,
            metadata_expand.cache_seqlens_int32,
            self.speculative_num_draft_tokens,
        )

        if metadata_swa is None:
            metadata_swa = FlashAttentionMetadata()
            metadata_swa.max_seq_len_q = 1
            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
            metadata_swa.cu_seqlens_k = cu_seqlens_k
            metadata_swa.page_table = page_table
        else:
            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)

        metadata.swa_spec_metadata = metadata_swa

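For intuition, here is a small self-contained sketch (an editor's illustration, not part of the diff) of how `_init_sliding_window_attn_spec_metadata` assembles the per-draft-token KV lengths; the batch size, prefix lengths, and expand lengths below are made-up example values.

import torch

speculative_num_draft_tokens = 3
prefix_lens = torch.tensor([10, 7], dtype=torch.int32)  # plays the role of metadata.cache_seqlens_int32
expand_lens = torch.tensor([1, 2, 3, 1, 2, 3], dtype=torch.int32)  # metadata_expand.cache_seqlens_int32

# Each draft token attends to its request's shared prefix plus its own expanded suffix.
cache_seqlens_int32 = (
    prefix_lens.repeat_interleave(speculative_num_draft_tokens) + expand_lens
)
# -> tensor([11, 12, 13,  8,  9, 10], dtype=torch.int32)

cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)
# -> tensor([ 0, 11, 23, 36, 44, 53, 63], dtype=torch.int32)
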
@triton.jit
def _prepare_swa_spec_page_table_kernel(
    dst_ptr,
    src_a_ptr,
    src_b_ptr,
    seq_len_a_ptr,
    seq_len_b_ptr,
    dst_stride_m,
    dst_stride_n,
    a_stride_m,
    a_stride_n,
    b_stride_m,
    b_stride_n,
    LEN_A: tl.constexpr,
    LEN_B: tl.constexpr,
    REPEAT_STEP: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    idx_a = pid_m // REPEAT_STEP
    idx_b = pid_m
    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
    seq_len_b = tl.load(seq_len_b_ptr + idx_b)

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    total_len = seq_len_a + seq_len_b

    if pid_n * BLOCK_N >= total_len:
        return

    mask = offs_n < total_len
    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n

    if (pid_n + 1) * BLOCK_N < seq_len_a:
        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
        a_mask = mask & (offs_n < LEN_A)
        val = tl.load(a_ptr, mask=a_mask, other=0)
        tl.store(dst, val, mask=mask)
    elif pid_n * BLOCK_N >= seq_len_a:
        offs_b = offs_n - seq_len_a
        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
        b_mask = mask & (offs_b < LEN_B)
        val = tl.load(b_ptr, mask=b_mask, other=0)
        tl.store(dst, val, mask=mask)
    else:
        # mixed part
        a_offs = offs_n
        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
        a_val = tl.load(a_ptr, mask=a_mask, other=0)

        b_offs = offs_n - seq_len_a
        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
        b_val = tl.load(b_ptr, mask=b_mask, other=0)

        result = tl.where(offs_n < seq_len_a, a_val, b_val)
        tl.store(dst, result, mask=mask)

def prepare_swa_spec_page_table_triton(
    page_table_dst: torch.Tensor,
    page_table_a: torch.Tensor,
    page_table_b: torch.Tensor,  # expand page table
    seq_len_a: torch.Tensor,
    seq_len_b: torch.Tensor,  # expand seq lens
    speculative_num_draft_tokens: int,
):
    # concat page_table and expand page_table by kv seq length
    bs = seq_len_a.numel()
    bs_expand = seq_len_b.numel()
    assert bs_expand == bs * speculative_num_draft_tokens

    LEN_A = page_table_a.shape[1]
    LEN_B = page_table_b.shape[1]
    LEN_OUT = LEN_A + LEN_B
    REPEAT_STEP = speculative_num_draft_tokens
    BLOCK_N = 256

    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
    _prepare_swa_spec_page_table_kernel[grid](
        page_table_dst,
        page_table_a,
        page_table_b,
        seq_len_a,
        seq_len_b,
        page_table_dst.stride(0),
        page_table_dst.stride(1),
        page_table_a.stride(0),
        page_table_a.stride(1),
        page_table_b.stride(0),
        page_table_b.stride(1),
        LEN_A=LEN_A,
        LEN_B=LEN_B,
        REPEAT_STEP=REPEAT_STEP,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )
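As a mental model for what the kernel produces, the following pure-PyTorch loop (an editor's reference sketch with a hypothetical name, not part of the diff) builds the same destination page table row by row; comparing `prepare_swa_spec_page_table_triton` against it on random inputs is a simple way to sanity-check the kernel.

import torch

def prepare_swa_spec_page_table_reference(
    page_table_dst: torch.Tensor,
    page_table_a: torch.Tensor,
    page_table_b: torch.Tensor,
    seq_len_a: torch.Tensor,
    seq_len_b: torch.Tensor,
    speculative_num_draft_tokens: int,
):
    # Row i of the destination holds the first seq_len_a[i // num_draft] page-table
    # entries of request i // num_draft, followed by the first seq_len_b[i] entries
    # of the expand page table for draft token i; the rest of the row is untouched.
    for i in range(seq_len_b.numel()):
        a = i // speculative_num_draft_tokens
        la = int(seq_len_a[a])
        lb = int(seq_len_b[i])
        page_table_dst[i, :la] = page_table_a[a, :la]
        page_table_dst[i, la : la + lb] = page_table_b[i, :lb]
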
class FlashAttentionMultiStepBackend: