Files
sglang/python/sglang/srt/layers/attention/flashattention_backend.py
2025-04-15 14:45:15 -07:00

1114 lines
46 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel.flash_attn import flash_attn_with_kvcache
@dataclass
class FlashAttentionMetadata:
    """Per-batch attention metadata.

    Built once per model forward pass by the backend's init-metadata methods,
    then reused by every layer's forward pass.
    """

    # Per-request KV sequence lengths as int32 (passed to the kernel as
    # ``cache_seqlens``)
    cache_seqlens_int32: Optional[torch.Tensor] = None
    # Maximum sequence length for query (passed to the kernel as ``max_seqlen_q``)
    max_seq_len_q: int = 0
    # Maximum sequence length for key; also bounds how many req_to_token columns
    # are sliced into ``page_table``
    max_seq_len_k: int = 0
    # Cumulative sequence lengths for query: a leading 0 followed by running sums,
    # one entry per request plus one
    cu_seqlens_q: Optional[torch.Tensor] = None
    # Cumulative sequence lengths for key: a leading 0 followed by running sums,
    # one entry per request plus one
    cu_seqlens_k: Optional[torch.Tensor] = None
    # Window size (typically used by Gemma)
    # NOTE(review): the forward paths compute window_size per layer and do not
    # appear to read this field -- confirm whether it is still needed.
    window_size: tuple = (-1, -1)
    # Page table, the index of KV Cache Tables/Blocks: one row per request.
    # Holds KV-cache slot indices, converted to page indices when page_size > 1.
    page_table: Optional[torch.Tensor] = None

    # Encoder metadata
    # Cumulative sequence lengths for encoder key
    encoder_cu_seqlens_k: Optional[torch.Tensor] = None
    # Maximum sequence length for encoder key
    encoder_max_seq_len_k: int = 0
    # Encoder sequence lengths for the forward batch, as int32
    encoder_lens_int32: Optional[torch.Tensor] = None
    # Page table for the encoder
    encoder_page_table: Optional[torch.Tensor] = None

    @dataclass
    class LocalAttentionMetadata:
        """Metadata for chunked ("local") attention.

        Each attention chunk of each request is run as an independent
        "virtual" batch item, so these tensors are indexed by virtual batch.
        """

        local_query_start_loc: Optional[torch.Tensor] = None  # cu_seqlens_q for local attention
        local_seqused_k: Optional[torch.Tensor] = None  # sequence lengths for local attention
        local_block_table: Optional[torch.Tensor] = None  # block table for local attention
        local_max_query_len: int = 0  # max query length for local attention
        local_max_seq_len: int = 0  # max sequence length for local attention

    # Set only when chunked local attention is enabled for this batch
    local_attn_metadata: Optional[LocalAttentionMetadata] = None
# Copied from:
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# seq_lens_np = [6, 17, 9] (kv_seqlens, as above)
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    """
    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
    local attention blocks, where each block is passed to the attention kernel
    as an independent local ("virtual") batch item.

    Args:
        attn_chunk_size: Size of local attention chunks, in tokens. Must be a
            positive multiple of ``page_size``.
        query_start_loc_np: Cumulative sum of query lengths (numpy array of
            length batch + 1, starting at 0). Every query length is assumed > 0.
        seq_lens_np: Total KV length per request (tokens before the query plus
            the query tokens), numpy array of length batch.
        block_table: Per-request page table; column ``p`` of row ``b`` is the
            cache page holding tokens ``[p * page_size, (p + 1) * page_size)``
            of request ``b``.
        page_size: Number of tokens per KV-cache page. Must be > 0.

    Returns:
        seqlens_q_local: Query sequence lengths for local attention
        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local
            attention (int32, leading 0)
        seqlens_k_local: Key sequence lengths for local attention (int32)
        block_table_local: Block table for local attention, shape
            (virtual_batches, attn_chunk_size // page_size)

    Raises:
        ValueError: if ``page_size`` or ``attn_chunk_size`` is not positive, or
            ``attn_chunk_size`` is not divisible by ``page_size``.
    """
    # Validate before any division by page_size (the default page_size=0 would
    # otherwise fail with a divide-by-zero deep inside the computation).
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")
    if attn_chunk_size <= 0:
        raise ValueError(f"attn_chunk_size must be positive, got {attn_chunk_size}")
    if attn_chunk_size % page_size != 0:
        raise ValueError(
            f"attn_chunk_size {attn_chunk_size} is not "
            f"divisible by page_size {page_size}"
        )

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    # 1 + ceil(remaining / attn_chunk_size), via negated floor division so it
    # stays in exact integer arithmetic.
    local_blocks = 1 + -(
        (q_seqlens - q_tokens_in_first_block) // -attn_chunk_size
    )

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx, we can figure out the number of "virtual" requests we
    # have to make.
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = int(cu_num_blocks[-1])
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1

    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(virtual_batches, attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // page_size
    pages_per_local_batch = attn_chunk_size // page_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming page_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch),
    ) + np.expand_dims(block_starts, axis=1)
    # Clamp to the table width: the last chunk of a sequence may extend past the
    # pages actually allocated. Those trailing entries are never read because
    # seqlens_k_local bounds the valid keys of each virtual batch.
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )
    # Gather with int64 index tensors on the table's own device so the lookup
    # does not depend on implicit numpy -> tensor index conversion.
    index_device = block_table.device
    batch_index_t = torch.from_numpy(batch_indices.astype(np.int64)).to(index_device)
    block_index_t = torch.from_numpy(block_indices.astype(np.int64)).to(index_device)
    block_table_local = block_table[batch_index_t, block_index_t].view(
        virtual_batches, -1
    )

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
def cdiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` computed with integer floor division only.

    Dividing by the negated divisor floors toward -inf; negating that result
    yields the ceiling without going through floating point. Applies
    element-wise when given NumPy integer arrays.
    """
    floored_negative_quotient = a // -b
    return -floored_negative_quotient
class FlashAttentionBackend(AttentionBackend):
    """FlashAttention backend implementation.

    Note about the init:
    - If no spec decoding
        - FlashAttentionBackend will be init once when the server starts.
    - If spec decoding
        - FlashAttentionBackend will be init once for the target worker
        - FlashAttentionMultiStepBackend will be init once for the draft worker
            - It will spawn num_steps FlashAttentionBackend for the draft worker

    Note about CUDA Graph:
    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
    - We don't support CUDA Graph for Extend and Draft Extend.
    - When server init, init_cuda_graph_state will be called first and then
      init_forward_metadata_capture_cuda_graph will be called.
    - For each forward batch, init_forward_metadata_replay_cuda_graph will be
      called first and then the graph is replayed.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        speculative_step_id=0,
        topk=0,
        speculative_num_steps=0,
    ):
        super().__init__()

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        self.forward_metadata: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.decode_cuda_graph_metadata = {}
        self.target_verify_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
        self.page_size = model_runner.page_size
        self.use_mla = (
            model_runner.model_config.attention_arch == AttentionArch.MLA
        ) and (not global_server_args_dict["disable_mla"])
        self.skip_prefill = skip_prefill

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.speculative_num_draft_tokens = (
            model_runner.server_args.speculative_num_draft_tokens
        )
        self.speculative_step_id = speculative_step_id

        # Local (chunked) attention settings; None disables local attention.
        self.attention_chunk_size = getattr(model_runner, "attention_chunk_size", None)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
        metadata = FlashAttentionMetadata()
        seqlens_in_batch = forward_batch.seq_lens
        batch_size = len(seqlens_in_batch)
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode_or_idle():
            if forward_batch.spec_info is not None:
                # Draft Decode: each draft step has appended one more token per
                # request, hence the (speculative_step_id + 1) offset.
                metadata.cache_seqlens_int32 = (
                    seqlens_in_batch + (self.speculative_step_id + 1)
                ).to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
        elif forward_batch.forward_mode.is_target_verify():
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + self.speculative_num_draft_tokens
            ).to(torch.int32)
            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                forward_batch.seq_lens_cpu.max().item()
                + self.speculative_num_draft_tokens
            )
            metadata.cu_seqlens_q = torch.arange(
                0,
                batch_size * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                # No cached prefix: queries and keys cover the same tokens.
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
            assert (
                forward_batch.encoder_lens.numel() == 1
            ), "Only encoder size 1 is supported for now"

            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
            ]

            # Currently only support forward_batch.encoder_lens.numel() == 1
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices,
                metadata.encoder_max_seq_len_k : (
                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
                ),
            ]

        # Convert the page table to a strided format which is needed by FA3 API
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        # Local attention is built last: make_local_attention_virtual_batches
        # indexes the page table in units of pages, so it must see the
        # page-index table produced by the conversion above.
        if (
            self.attention_chunk_size is not None
            and forward_batch.forward_mode == ForwardMode.EXTEND
        ):
            self._init_local_attn_metadata(metadata, device)

        self.forward_metadata = metadata

    def _init_local_attn_metadata(
        self, metadata: FlashAttentionMetadata, device: torch.device
    ) -> None:
        """Populate ``metadata.local_attn_metadata`` for chunked local attention.

        Must be called after ``metadata.page_table`` holds page indices (i.e. after
        the page_size > 1 conversion in ``init_forward_metadata``).
        """
        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
        seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()

        # Shrink the chunk to the longest sequence, rounded UP to whole pages, to
        # keep the local block table small. Rounding up keeps every sequence in a
        # single chunk whenever the configured chunk is larger than the batch, so
        # the attention pattern is unchanged. (Rounding down could split a
        # sequence into chunks smaller than attention_chunk_size when
        # page_size > 1.)
        max_seq_len = int(seq_lens_np.max())
        max_seq_len_paged = cdiv(max_seq_len, self.page_size) * self.page_size
        effective_chunk_size = max(
            min(self.attention_chunk_size, max_seq_len_paged), self.page_size
        )

        (
            seqlens_q_local_np,
            cu_seqlens_q_local_np,
            seqlens_k_local_np,
            block_table_local,
        ) = make_local_attention_virtual_batches(
            effective_chunk_size,
            cu_seqlens_q_np,
            seq_lens_np,
            metadata.page_table,
            self.page_size,
        )
        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
            local_block_table=block_table_local,
            local_max_query_len=int(seqlens_q_local_np.max()),
            local_max_seq_len=int(seqlens_k_local_np.max()),
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Attention for extend/prefill batches via ``flash_attn_with_kvcache``.

        Writes ``k``/``v`` (when given) into the KV pool, then attends over the
        pooled cache using ``self.forward_metadata``. Returns a tensor viewed as
        ``(-1, tp_q_head_num * v_head_dim)``.
        """
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None
        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
            # NOTE(review): with local attention the kernel's batch is the number
            # of virtual batches, not forward_batch.batch_size -- confirm shape.
            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
            k_descale = layer.k_scale.expand(descale_shape)
            v_descale = layer.v_scale.expand(descale_shape)
            q = q.to(self.kv_cache_dtype)
        causal = not layer.is_cross_attention

        # Check if we should use local attention
        use_local_attn = (
            self.attention_chunk_size is not None
            and metadata.local_attn_metadata is not None
            and (hasattr(layer, "use_irope") and layer.use_irope)
        )

        # Get the appropriate page table based on whether we're using local attention
        if use_local_attn:
            local_metadata = metadata.local_attn_metadata
            page_table = local_metadata.local_block_table
            cu_seqlens_q = local_metadata.local_query_start_loc
            cache_seqlens = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
        else:
            page_table = metadata.page_table
            cu_seqlens_q = metadata.cu_seqlens_q
            cache_seqlens = metadata.cache_seqlens_int32
            max_seqlen_q = metadata.max_seq_len_q
        cu_seqlens_k = metadata.cu_seqlens_k

        # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)

            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]
            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        """Attention for decode batches (one query token per request).

        Writes ``k``/``v`` (when given) into the KV pool, then attends over the
        pooled cache using ``self.forward_metadata``. Returns a tensor viewed as
        ``(-1, tp_q_head_num * v_head_dim)``.
        """
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
        causal = not layer.is_cross_attention

        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None
        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
            k_descale = layer.k_scale.expand(descale_shape)
            v_descale = layer.v_scale.expand(descale_shape)
            q = q.to(self.kv_cache_dtype)

        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)
            else:
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k

            o = flash_attn_with_kvcache(
                q=q_reshaped,
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        # Number of page-table columns needed to cover max_context_len tokens.
        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size

        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "page_table": torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
            "page_table_draft_decode": torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
            # First token slot of every page; used to turn req_to_token rows
            # into page indices.
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

        self.target_verify_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "page_table": torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

        self.encoder_metadata = {
            "encoder_page_table": torch.zeros(
                max_bs,
                self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            ),
            "encoder_lens_int32": torch.zeros(
                max_bs, dtype=torch.int32, device=self.device
            ),
            "encoder_cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Initialize forward metadata for capturing CUDA graph.

        The metadata built here is stored per batch size so that
        ``init_forward_metadata_replay_cuda_graph`` can refresh the very same
        tensors in place before each replay.
        """
        metadata = FlashAttentionMetadata()
        device = seq_lens.device
        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]
                metadata.max_seq_len_k = seq_lens.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                    : bs + 1
                ]
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][req_pool_indices, :]
            else:
                # Normal Decode
                # Get sequence information
                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
                batch_size = len(seq_lens)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
                # Precompute maximum sequence length
                metadata.max_seq_len_k = seq_lens.max().item()
                # Precompute page table
                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
                    req_pool_indices, :
                ]
                # Precompute cumulative sequence lengths
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
            self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )

            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                seq_lens.max().item() + self.speculative_num_draft_tokens
            )

            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )

            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
                : (bs + 1)
            ]

            metadata.page_table = self.target_verify_metadata["page_table"][
                req_pool_indices, :
            ]

            self.target_verify_metadata[bs] = metadata

        if encoder_lens is not None:
            encoder_bs = encoder_lens.numel()
            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
                :encoder_bs
            ]
            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
                "encoder_cu_seqlens_k"
            ][: (encoder_bs + 1)]

            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
                req_pool_indices, :
            ]

        self.forward_metadata = metadata

    def _copy_page_table_for_replay(
        self,
        metadata: FlashAttentionMetadata,
        req_pool_indices: torch.Tensor,
        strided_indices: torch.Tensor,
    ) -> int:
        """Write page indices for ``req_pool_indices`` into ``metadata.page_table`` in place.

        Takes the first token slot of each page from ``req_to_token`` and divides
        by ``page_size``, matching the conversion in ``init_forward_metadata``;
        with ``page_size == 1`` this is a plain copy of the token slots.
        Returns the number of page-table columns written.
        """
        max_seq_pages = (
            metadata.max_seq_len_k + self.page_size - 1
        ) // self.page_size
        page_indices = self.req_to_token[
            req_pool_indices[:, None],
            strided_indices[:max_seq_pages][None, :],
        ]
        page_indices //= self.page_size
        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
        return max_seq_pages

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: torch.Tensor = None,
    ):
        """Initialize forward metadata for replaying CUDA graph.

        The captured graph reads the tensors that were stored in the metadata at
        capture time, so every tensor update here is an in-place ``copy_`` /
        ``fill_``; rebinding an attribute to a new tensor would be invisible to
        the replayed graph.
        """
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        if forward_mode.is_decode_or_idle():
            metadata = self.decode_cuda_graph_metadata[bs]

            if spec_info is not None:
                # Draft Decode
                metadata.cache_seqlens_int32.copy_(
                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                )

                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_k.copy_(
                    torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                )
                self._copy_page_table_for_replay(
                    metadata,
                    req_pool_indices,
                    self.decode_cuda_graph_metadata["strided_indices"],
                )
            else:
                # Normal Decode
                metadata.max_seq_len_k = seq_lens_cpu.max().item()
                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
                metadata.cu_seqlens_k.copy_(
                    torch.nn.functional.pad(
                        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                    )
                )
                max_seq_pages = self._copy_page_table_for_replay(
                    metadata,
                    req_pool_indices,
                    self.decode_cuda_graph_metadata["strided_indices"],
                )
                # Clear entries left over from a previous, longer batch.
                metadata.page_table[:, max_seq_pages:].fill_(0)

        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )

            metadata.max_seq_len_k = (
                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
            )
            metadata.cu_seqlens_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
            )
            self._copy_page_table_for_replay(
                metadata,
                req_pool_indices,
                self.target_verify_metadata["strided_indices"],
            )

        if encoder_lens is not None:
            # Only support encoder size 1 for now
            metadata.encoder_max_seq_len_k = encoder_lens[0]
            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
            metadata.encoder_cu_seqlens_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
                    (1, 0),
                )
            )

            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
            )

            # Update the regular page table
            page_table = self.req_to_token[
                req_pool_indices,
                metadata.encoder_max_seq_len_k : (
                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
                ),
            ]
            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 0
class FlashAttentionMultiStepBackend:
    """Draft-worker attention backend holding one FlashAttentionBackend per step.

    Step ``i`` of EAGLE speculative drafting uses ``attn_backends[i]``. The
    metadata init/capture/replay entry points touch every step except the last.
    """

    def __init__(
        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
    ):
        self.model_runner = model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
        assert (
            self.topk == 1
        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"

        self.attn_backends = [
            FlashAttentionBackend(
                model_runner,
                speculative_step_id=step_id,
                topk=self.topk,
                speculative_num_steps=self.speculative_num_steps,
            )
            for step_id in range(self.speculative_num_steps)
        ]

    def _leading_step_backends(self):
        """Backends for all draft steps except the final one, in step order."""
        return [
            self.attn_backends[step_id]
            for step_id in range(self.speculative_num_steps - 1)
        ]

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        for backend in self._leading_step_backends():
            backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int):
        every_step = [
            self.attn_backends[step_id]
            for step_id in range(self.speculative_num_steps)
        ]
        for backend in every_step:
            backend.init_cuda_graph_state(max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        forward_batch: ForwardBatch,
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for backend in self._leading_step_backends():
            batch_size = forward_batch.batch_size
            backend.init_forward_metadata_capture_cuda_graph(
                batch_size,
                batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for backend in self._leading_step_backends():
            backend.init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
                out_cache_loc=forward_batch.out_cache_loc,
            )