from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
class FlashAttentionMetadata:
"""Metadata to be init once in the model forward pass,
each layer's forward pass can reuse the metadata.
For each init metadata function, we will try set up them in below order
"""
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None
# Maximum sequence length for query
max_seq_len_q: int = 1
# Maximum sequence length for key
max_seq_len_k: int = 0
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor = None
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor = None
# Window size (typically used by Gemma)
window_size: tuple = (-1, -1)
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None
# Encoder metadata
# Cumulative sequence lengths for encoder key
encoder_cu_seqlens_k: torch.Tensor = None
# Maximum sequence length for encoder key
encoder_max_seq_len_k: int = 0
    # Encoder sequence lengths for the forward batch
encoder_lens_int32: torch.Tensor = None
# Page table for the encoder
encoder_page_table: torch.Tensor = None
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
local_seqused_k: torch.Tensor = None # sequence lengths for local attention
local_block_table: torch.Tensor = None # block table for local attention
local_max_query_len: int = 0 # max query length for local attention
local_max_seq_len: int = 0 # max sequence length for local attention
local_attn_metadata: Optional[LocalAttentionMetadata] = None
    # For sliding window attention with topk > 1 spec decoding
swa_spec_metadata: Optional[FlashAttentionMetadata] = None
# Copied from:
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
#   cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.Tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
"""
Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
local attention blocks, where each block is passed to the attention kernel
as an independent local ("virtual") batch item.
Args:
attn_chunk_size: Size of local attention chunks
query_start_loc_np: Cumulative sum of query lengths (numpy array)
seq_lens_np: Sequence lengths (numpy array)
block_table: Block table for KV cache
page_size: Size of each page in the KV cache
Returns:
seqlens_q_local: Query sequence lengths for local attention
cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
seqlens_k_local: Key sequence lengths for local attention
block_table_local: Block table for local attention
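    Example (an illustrative sketch of the worked example above, with page_size=1
    and a hypothetical block table):
        >>> query_start_loc = np.array([0, 4, 14, 19])  # q_seqlens = [4, 10, 5]
        >>> seq_lens = np.array([6, 17, 9])
        >>> block_table = torch.arange(60, dtype=torch.int32).view(3, 20)
        >>> q_local, cu_q_local, k_local, _ = make_local_attention_virtual_batches(
        ...     4, query_start_loc, seq_lens, block_table, page_size=1
        ... )
        >>> q_local.tolist()
        [2, 2, 1, 4, 4, 1, 4, 1]
        >>> k_local.tolist()
        [4, 2, 4, 4, 4, 1, 4, 1]
        >>> cu_q_local.tolist()
        [0, 2, 4, 5, 9, 13, 14, 18, 19]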
"""
# Adjust attention_chunk_size based on the actual sequence length
# to avoid index out of bounds errors
max_seq_len = seq_lens_np.max()
effective_chunk_size = min(attn_chunk_size, max_seq_len)
# Make sure effective_chunk_size is divisible by page_size
effective_chunk_size = (effective_chunk_size // page_size) * page_size
if effective_chunk_size < page_size:
effective_chunk_size = page_size
attn_chunk_size = effective_chunk_size
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
    # Handle if we are starting in the middle of a local attention block:
    # we assume q_seqlens > 0 (for all elements). For each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
).astype(np.int32)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
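    # e.g. seq_len = 17, attn_chunk_size = 4: 17 % -4 == -3, so the last block
    # holds 4 + (-3) == 1 token; a seq_len that is a multiple of the chunk size
    # yields a full block of attn_chunk_size tokens rather than 0.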
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    # (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
)
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, (
f"attn_chunk_size {attn_chunk_size} is not "
f"divisible by page_size {page_size}"
)
pages_per_local_batch = attn_chunk_size // page_size
# Create a block_table for the local attention blocks
    # For our example, if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices = np.broadcast_to(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch),
) + np.expand_dims(block_starts, axis=1)
# Ensure block_indices doesn't exceed block_table dimensions
# This is a critical safety check that prevents index out of bounds errors
# when dealing with large sequences (>8192 tokens) or when the block_table
# dimensions are smaller than what would be needed for the full attention chunk size.
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(
np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch,
)
block_table_local = block_table[batch_indices, block_indices].view(
virtual_batches, -1
)
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
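

# Example: cdiv(9, 4) == -(9 // -4) == -(-3) == 3; Python floor division
# rounds toward -inf, so negating the operands and the result rounds up.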
# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
@torch._dynamo.disable()
def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
return merge_state_v2(o, s_a, o_exp, s_b)
class FlashAttentionBackend(AttentionBackend):
"""FlashAttention backend implementation.
Note about the init:
- If no spec decoding
- FlashAttentionBackend will be init once when the server starts.
- If spec decoding
- FlashAttentionBackend will be init once for the target worker
    - FlashAttentionMultiStepBackend will be init once for the draft worker
- It will spawn num_steps FlashAttentionBackend for the draft worker
Note about CUDA Graph:
- We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
- We don't support CUDA Graph for Extend and Draft Extend.
- When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
- For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
"""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
fa_impl_ver=3,
):
super().__init__()
assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
self.forward_metadata: FlashAttentionMetadata = None
        # Extra metadata for handling speculative decoding with topk > 1, extended draft decode, and verify
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
)
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
)
self.speculative_step_id = speculative_step_id
self.fa_impl_ver = fa_impl_ver
# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
if hasattr(model_runner, "attention_chunk_size")
else None
)
# For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
# We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
self.sliding_window_size = model_runner.sliding_window_size
self.has_swa = (
self.sliding_window_size is not None and self.sliding_window_size > -1
)
# If num_splits == 0, we use a heuristic to automatically determine the number of splits.
        # We set num_splits to 1 if deterministic inference is enabled.
# See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
seqlens_in_batch = forward_batch.seq_lens
batch_size = forward_batch.batch_size
device = seqlens_in_batch.device
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
).to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
decode_length = self.speculative_step_id + 1
metadata_expand.cache_seqlens_int32 = torch.full(
(seqlens_in_batch.numel() * self.topk,),
decode_length,
device=device,
dtype=torch.int32,
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() + 1,
dtype=torch.int32,
device=device,
)
metadata_expand.cu_seqlens_k = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=device,
)
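                    # e.g. with bs * topk = 4 expanded rows and decode_length = 2
                    # (illustrative numbers), this yields [0, 2, 4, 6, 8]: each
                    # expanded row owns exactly decode_length KV entries.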
# shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
cache_loc = forward_batch.out_cache_loc.view(
-1, self.speculative_num_steps
)
metadata_expand.page_table = (
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
self.forward_metadata_spec_decode_expand = metadata_expand
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(forward_batch, metadata, device)
elif forward_batch.forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item()
+ self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(forward_batch, metadata, device)
else:
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ 1,
dtype=torch.int32,
device=device,
)
# create expand page table
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(
forward_batch.seq_lens.numel(), -1
) + forward_batch.seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
forward_batch.seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
mask = forward_batch.spec_info.custom_mask[
mask_extraction_indices
].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
                # Shift table indices to avoid padding.
                #   non_masked_page_table:      mask (shown as ints):
                #     [[8, 9, 10],                [[1, 0, 0],
                #      [8, 9, 10],                 [1, 1, 0],
                #      [8, 9, 10]]                 [1, 0, 1]]
                #   if masked with padding:     ours, without padding:
                #     [[8, 0, 0],                 [[8, 9, 10],
                #      [8, 9, 0],                  [8, 9, 10],
                #      [8, 0, 10]]                 [8, 10, 9]]
                # Note: cache_seqlens_int32 is [1, 2, 2] here, so the extra page
                # indices at the tail of each row are ignored.
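                # A minimal sketch of the shift trick (hypothetical values):
                #   col_indices = [[0, 1, 2]] * 3
                #   keys = where(mask, col, col + 3) = [[0, 4, 5], [0, 1, 5], [0, 4, 2]]
                #   argsort(keys, dim=1)           = [[0, 1, 2], [0, 1, 2], [0, 2, 1]]
                # so gathering with the sort order permutes each row's valid
                # pages to the front.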
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
# Build keys: if an entry is valid (mask==True), keep its original index;
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                )  # (bsz * draft_num, draft_num)
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
self.forward_metadata_spec_decode_expand = metadata_expand
if self.has_swa:
self._init_sliding_window_attn_spec_metadata(
metadata, metadata_expand
)
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if (
any(forward_batch.extend_prefix_lens_cpu)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
extend_seq_lens = forward_batch.extend_seq_lens
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
metadata.cu_seqlens_q = torch.nn.functional.pad(
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
else:
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Setup local attention if enabled
if forward_batch.forward_mode == ForwardMode.EXTEND:
self._init_local_attn_metadata(forward_batch, metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
assert (
forward_batch.encoder_lens.numel() == 1
), "Only encoder size 1 is supported for now"
metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
(1, 0),
)
metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
]
# Currently only support forward_batch.encoder_lens.numel() == 1
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
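            # The req_to_token row is laid out as [encoder KV | decoder KV], so
            # the decoder page table starts at offset encoder_max_seq_len_k.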
# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
self.strided_indices = torch.arange(
0, metadata.page_table.shape[1], self.page_size, device=self.device
)
metadata.page_table = (
metadata.page_table[:, self.strided_indices] // self.page_size
)
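            # e.g. with page_size = 2 and a page_table row [8, 9, 10, 11], the
            # strided columns [0, 2] pick [8, 10], which divide down to page
            # ids [4, 5].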
self.forward_metadata = metadata
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
# Use precomputed metadata across all layers
metadata = self.forward_metadata
        # Calculate window size (can be moved to metadata if layer properties don't change)
        # We don't subtract 1 from layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already does so.
        # The window is inclusive on both sides.
is_swa = (
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
k_descale, v_descale = None, None
        # Only use kv scaling if: 1) fp8 kv cache is explicitly enabled, 2) RadixAttention
        # has a corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256, since the fa3 kernel requires fp16/bf16 data types otherwise,
        # 4) fa_impl_ver != 4, since fa4 does not currently support fp8 queries and keys.
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
and self.fa_impl_ver != 4
):
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
causal = not layer.is_cross_attention
# Check if we should use local attention
use_local_attn = (
self.attention_chunk_size is not None
and metadata.local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# We do cascade attention for Target Verify with topk > 1
# We don't use cascade attention for Sliding Window Attention:
# - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
# - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
use_cascade_attn = (
forward_batch.forward_mode.is_target_verify()
and self.topk > 1
and not is_swa
)
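        # With cascade attention, the prefix pass below runs with causal=False
        # and returns the softmax LSE; a second pass over the expanded
        # draft-token metadata is then merged with it via merge_state_v2_wrapper.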
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks
# Get the appropriate page table based on whether we're using local attention
if use_local_attn:
local_metadata = metadata.local_attn_metadata
page_table = local_metadata.local_block_table
cu_seqlens_q = local_metadata.local_query_start_loc
cache_seqlens = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
elif is_swa and metadata.swa_spec_metadata is not None:
swa_spec_metadata = metadata.swa_spec_metadata
page_table = swa_spec_metadata.page_table
cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
cache_seqlens = swa_spec_metadata.cache_seqlens_int32
max_seqlen_q = swa_spec_metadata.max_seq_len_q
cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
else:
page_table = metadata.page_table
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens = metadata.cache_seqlens_int32
max_seqlen_q = metadata.max_seq_len_q
cu_seqlens_k = metadata.cu_seqlens_k
# Use Flash Attention for prefill
if not self.use_mla:
            assert self.fa_impl_ver in [3], "Only FA3 is supported here"
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
if (
forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"]
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
assert forward_batch.mha_return_lse
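                # The returned LSE lets the caller merge this chunk's partial
                # output with the outputs of the other prefix chunks.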
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
**kwargs,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output
else:
            assert self.fa_impl_ver in [3], "Only FA3 is supported here"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).to(q.dtype)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        assert self.fa_impl_ver in [3], "Only FA3 supports decoding"
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
# Use precomputed metadata across all layers
metadata = self.forward_metadata
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
use_local_attn = (
self.attention_chunk_size is not None
and local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
        # When spec decoding is enabled, forward_decode is called in two modes:
        # 1. DRAFT_DECODE: we enable cascade attention when topk > 1
        # 2. IDLE: we don't need cascade attention; spec_info will be None in this case
use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1
        # Calculate window size (can be moved to metadata if layer properties don't change)
        # We don't subtract 1 from layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already does so.
        # The window is inclusive on both sides.
window_size = (
(layer.sliding_window_size, 0)
if layer.sliding_window_size is not None and layer.sliding_window_size > -1
else (-1, -1)
)
causal = not layer.is_cross_attention
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks
k_descale, v_descale = None, None
        # Only use kv scaling if: 1) fp8 kv cache is explicitly enabled, 2) RadixAttention
        # has a corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256, since the fa3 kernel requires fp16/bf16 data types otherwise.
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
# Always use non-chunked logic for cross-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.encoder_page_table,
cache_seqlens=metadata.encoder_lens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=False,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
elif use_local_attn:
# Use chunked (local) attention batching for self-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=local_attn_metadata.local_block_table,
cache_seqlens=local_attn_metadata.local_seqused_k,
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
cu_seqlens_k_new=None,
max_seqlen_q=local_attn_metadata.local_max_query_len,
softmax_scale=layer.scaling,
causal=True,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
else:
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
max_seqlen_q = metadata.max_seq_len_q
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=metadata.page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Initialize CUDA graph state for the attention backend.
Args:
max_bs (int): Maximum batch size to support in CUDA graphs
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
# This is being used by normal decode and draft decode when topk == 1
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
# Only allocate local attention buffers if local attention is enabled
# This prevents OOM errors when local attention is not being used
if self.attention_chunk_size is not None:
# Estimate maximum sizes for local attention metadata
max_seq_len = self.max_context_len
page_size = self.page_size or 1
attn_chunk_size = self.attention_chunk_size
max_virtual_batches = max_bs * (
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
)
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
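            # e.g. (hypothetical sizes) max_bs = 8, max_seq_len = 8192,
            # attn_chunk_size = 2048, page_size = 64 gives
            # max_virtual_batches = 8 * 4 = 32 and max_pages_per_block = 32.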
self.decode_cuda_graph_local_attn_metadata = {
"local_query_start_loc": torch.zeros(
max_virtual_batches + 1, dtype=torch.int32, device=self.device
),
"local_seqused_k": torch.zeros(
max_virtual_batches, dtype=torch.int32, device=self.device
),
"local_block_table": torch.zeros(
max_virtual_batches,
max_pages_per_block,
dtype=torch.int32,
device=self.device,
),
}
# This is used by draft decode's first half of metadata when topk > 1
if self.topk > 1:
self.draft_decode_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
# This is used by draft decode's second half of metadata when topk > 1
decode_length = self.speculative_step_id + 1
self.draft_decode_metadata_topk_expand = {
"cache_seqlens": torch.full(
(max_bs * self.topk,),
decode_length,
device=self.device,
dtype=torch.int32,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.arange(
0,
max_bs * self.topk * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.topk,
decode_length,
dtype=torch.int32,
device=self.device,
),
}
if (
self.speculative_num_draft_tokens is not None
and self.speculative_num_draft_tokens > 0
):
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
)
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
self.draft_extend_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.zeros(
max_bs + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
self.target_verify_metadata_topk_expand = {
"cache_seqlens": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
}
if self.has_swa:
self.target_verify_metadata_topk_swa = {
"cache_seqlens": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
self.encoder_metadata = {
"encoder_page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
"encoder_lens_int32": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"encoder_cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
}
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
        # metadata_expand is needed for spec decoding when topk > 1
metadata_expand = FlashAttentionMetadata()
device = seq_lens.device
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
if self.topk <= 1:
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
"cache_seqlens"
][:bs]
metadata.max_seq_len_k = seq_lens.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][:bs, :]
self.decode_cuda_graph_metadata[bs] = metadata
else:
                    # When topk > 1, we need two sets of draft decode metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = seq_lens.max().item()
metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.draft_decode_metadata_topk_normal[
"page_table"
][:bs, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_expand["cache_seqlens"][
: bs * self.topk
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = (
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
: bs * self.topk + 1
]
)
metadata_expand.cu_seqlens_k = (
self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
: bs * self.topk + 1
]
)
metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
"page_table"
][: bs * self.topk]
self.draft_decode_metadata_topk_normal[bs] = metadata
self.draft_decode_metadata_topk_expand[bs] = metadata_expand
else:
# Normal Decode
# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
batch_size = len(seq_lens)
device = seq_lens.device
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
# Precompute maximum sequence length
metadata.max_seq_len_k = seq_lens.max().item()
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
:bs, :
]
# Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
self.decode_cuda_graph_metadata[bs] = metadata
if self.attention_chunk_size is not None:
self._update_local_attn_metadata_for_capture(metadata, batch_size)
elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = self.target_verify_metadata[
"cache_seqlens"
][:bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens)
)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
seq_lens.max().item() + self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
bs * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
self.target_verify_metadata[bs] = metadata
else:
                # When topk > 1, we need two sets of target verify metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
"cache_seqlens"
][:bs]
metadata.max_seq_len_q = self.speculative_num_draft_tokens
# metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.target_verify_metadata_topk_normal[
"page_table"
][:bs, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = (
self.target_verify_metadata_topk_expand["cache_seqlens"][
: bs * self.speculative_num_draft_tokens
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
"cu_seqlens_q"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
"cu_seqlens_k"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.page_table = self.target_verify_metadata_topk_expand[
"page_table"
][: bs * self.speculative_num_draft_tokens]
self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
if self.has_swa:
metadata_swa = FlashAttentionMetadata()
metadata_swa.cache_seqlens_int32 = (
self.target_verify_metadata_topk_swa["cache_seqlens"][
: bs * self.speculative_num_draft_tokens
]
)
metadata_swa.max_seq_len_q = 1
metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
"cu_seqlens_q"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
"cu_seqlens_k"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_swa.page_table = self.target_verify_metadata_topk_swa[
"page_table"
][: bs * self.speculative_num_draft_tokens]
self.target_verify_metadata_topk_swa[bs] = metadata_swa
metadata.swa_spec_metadata = metadata_swa
elif forward_mode.is_draft_extend():
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(seq_lens)
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
metadata.max_seq_len_k = seq_lens.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
self.draft_extend_metadata[bs] = metadata
if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
:encoder_bs
]
metadata.encoder_cu_seqlens_k = self.encoder_metadata[
"encoder_cu_seqlens_k"
][: (encoder_bs + 1)]
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
:bs, :
]
self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
"""Initialize forward metadata for replaying CUDA graph."""
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
device = seq_lens.device
metadata = None
metadata_expand = None
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
if self.topk <= 1:
# When topk = 1, we use the normal decode metadata
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
normal_decode_set_metadata(
metadata.cache_seqlens_int32,
metadata.cu_seqlens_k,
metadata.page_table,
self.req_to_token,
req_pool_indices,
self.decode_cuda_graph_metadata["strided_indices"],
max_seq_pages,
seq_lens,
self.speculative_step_id + 1,
self.page_size,
)
else:
                    # When topk > 1, we need two sets of draft decode metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata = self.draft_decode_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens)
# metadata.max_seq_len_q = self.topk, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
decode_length = self.speculative_step_id + 1
# shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
cache_loc[:, :decode_length]
)
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
else:
# Normal Decode
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.max_seq_len_k = max_len
normal_decode_set_metadata(
metadata.cache_seqlens_int32,
metadata.cu_seqlens_k,
metadata.page_table,
self.req_to_token,
req_pool_indices,
self.decode_cuda_graph_metadata["strided_indices"],
max_seq_pages,
seq_lens,
0,
self.page_size,
)
self._update_local_attn_metadata_for_replay(
metadata,
bs,
)
elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens)
)
metadata.max_seq_len_k = (
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
                # When topk > 1, we need two separate sets of target verify metadata, then merge their states
# 1. The first half of metadata for prefix tokens
metadata = self.target_verify_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens)
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)
metadata_expand = self.target_verify_metadata_topk_expand[bs]
# metadata_expand.max_seq_len_q = 1, already set in capture
# metadata_expand.cu_seqlens_q already set in capture
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
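                # cols[i, j] = seq_lens[i] + j, the KV position of request i's j-th draft token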
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
                # Avoid extracting padded sequence indices, which would be out of bounds
mask_extraction_indices[
:,
spec_info.positions.numel() * self.speculative_num_draft_tokens :,
].fill_(0)
mask = spec_info.custom_mask[mask_extraction_indices].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
keys = torch.where(
mask,
col_indices,
col_indices + self.speculative_num_draft_tokens,
)
_, sort_order = torch.sort(keys, dim=1)
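                # Worked example (assumed values): with draft_num=3 and a row mask
                # of [True, False, True], keys = [0, 4, 2] and sort_order = [0, 2, 1],
                # so the pages of accepted draft tokens move to the front; the
                # per-row KV length below (mask.sum() = 2) then covers exactly them.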
non_masked_page_table = (
self.req_to_token[req_pool_indices, :]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                ) # (bsz * draft_num, draft_num)
metadata_expand.page_table.copy_(
non_masked_page_table.gather(1, sort_order)
)
metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
metadata_expand.cu_seqlens_k[1:].copy_(
torch.cumsum(
metadata_expand.cache_seqlens_int32,
dim=0,
dtype=torch.int32,
)
)
if self.has_swa:
metadata_swa = self.target_verify_metadata_topk_swa[bs]
self._init_sliding_window_attn_spec_metadata(
metadata, metadata_expand, metadata_swa
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens)
metadata.max_seq_len_k = seq_lens_cpu.max().item()
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
accept_length = spec_info.accept_length[:bs]
if spec_info.accept_length_cpu:
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
else:
metadata.max_seq_len_q = 1
metadata.cu_seqlens_q[1:].copy_(
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)
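            # Paging example (assumed values): page_size=4 and max_seq_len_k=10
            # give max_seq_pages = (10 + 3) // 4 = 3; strided_indices is assumed
            # to hold token offsets [0, page_size, 2 * page_size, ...], so
            # dividing the gathered slots by page_size yields page ids.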
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
]
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
if encoder_lens is not None:
            # Only an encoder batch size of 1 is supported for now
metadata.encoder_max_seq_len_k = encoder_lens[0]
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
metadata.encoder_cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
)
metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
)
# Update the regular page table
page_table = self.req_to_token[
req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
self.forward_metadata = metadata
        self.forward_metadata_spec_decode_expand = metadata_expand

def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
        return 1

    def _init_local_attn_metadata(
        self, forward_batch: ForwardBatch, metadata: FlashAttentionMetadata, device
    ):
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
if self.attention_chunk_size is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
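        # For hybrid (SWA + full) KV caches, full_to_swa_index_mapping is
        # assumed to translate full-pool slot indices into sliding-window-pool
        # slot indices, so the local-attention kernel reads the SWA cache.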
if self.is_hybrid:
page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
torch.int32
)
else:
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
seq_lens_np = cache_seqlens_int32.cpu().numpy()
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
self.attention_chunk_size,
cu_seqlens_q_np,
seq_lens_np,
page_table,
self.page_size,
)
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
local_block_table=block_table_local.to(device),
local_max_query_len=int(seqlens_q_local_np.max()),
local_max_seq_len=int(seqlens_k_local_np.max()),
)
        metadata.local_attn_metadata = local_metadata

def _update_local_attn_metadata_for_capture(
self, metadata: FlashAttentionMetadata, bs: int
):
"""Update local attention metadata during CUDA graph capture phase.
This method calculates the exact buffer sizes needed for local attention metadata
during the CUDA graph capture phase, optimizing memory usage by creating views of
pre-allocated buffers with exactly the sizes needed.
"""
seq_lens_capture = metadata.cache_seqlens_int32
max_seq_len = int(seq_lens_capture.max().item())
page_table_capture = metadata.page_table
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
seqlens_np = seq_lens_capture.cpu().numpy()
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
            block_table_local,
) = make_local_attention_virtual_batches(
self.attention_chunk_size,
cu_seqlens_q_np,
seqlens_np,
page_table_capture,
self.page_size,
)
# Get exact dimensions from the calculation
q_len = len(cu_seqlens_q_local_np)
k_len = len(seqlens_k_local_np)
        b0 = block_table_local.shape[0] if block_table_local.shape[0] > 0 else bs
        b1 = block_table_local.shape[1] if block_table_local.shape[1] > 0 else 1
# Create views of the pre-allocated buffers with exactly these sizes
# This is the key optimization - we only use the memory we actually need
local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
"local_query_start_loc"
][:q_len]
local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
:k_len
]
local_block_table = self.decode_cuda_graph_local_attn_metadata[
"local_block_table"
][:b0, :b1]
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=local_block_table,
local_max_query_len=1,
local_max_seq_len=max_seq_len,
        )

def _update_local_attn_metadata_for_replay(
self,
metadata: FlashAttentionMetadata,
bs: int,
):
"""Update preallocated local attention metadata in-place before CUDA graph replay."""
if self.attention_chunk_size is None:
return
# Access preallocated buffers
local_q_buf = self.decode_cuda_graph_local_attn_metadata[
"local_query_start_loc"
]
local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
local_block_buf = self.decode_cuda_graph_local_attn_metadata[
"local_block_table"
]
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
# Create a modified version for local attention that only processes the last token
# This mimics the normal decode pattern
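        # (torch.arange(bs + 1) yields [0, 1, ..., bs], i.e. one query token per request)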
cu_seqlens_q = torch.arange(
bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
)
seqlens = metadata.cache_seqlens_int32[:bs]
# Slice the page_table to match the batch size and actual sequence length
# This serves three important purposes:
# 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
# 2. Limits the sequence length to prevent processing padding tokens or garbage values
# 3. Prevents zeros in the block table which can cause garbage output during replay
#
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
# beyond the actual sequence length, leading to incorrect attention calculations
max_seq_len = int(seqlens.max().item())
if self.is_hybrid:
sliced_page_table = self.full_to_swa_index_mapping[
metadata.page_table[:bs, :max_seq_len]
].to(torch.int32)
else:
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
seqlens_np = seqlens.cpu().numpy()
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
self.attention_chunk_size,
cu_seqlens_q_np,
seqlens_np,
sliced_page_table,
self.page_size,
)
# Convert back to tensors
device = local_q_buf.device
cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
block_table_local = block_table_local.to(device)
# Get sizes
q_len = cu_seqlens_q_local.shape[0]
k_len = seqlens_k_local.shape[0]
b0, b1 = block_table_local.shape
# In-place updates into preallocated tensors and zero out the unused space
local_q_buf[:q_len].copy_(cu_seqlens_q_local)
local_q_buf[q_len:].fill_(0)
local_k_buf[:k_len].copy_(seqlens_k_local)
local_k_buf[k_len:].fill_(0)
local_block_buf[:b0, :b1].copy_(block_table_local)
local_block_buf[b0:, :].fill_(0)
local_block_buf[:b0, b1:].fill_(0)
if metadata.local_attn_metadata is not None:
lam = metadata.local_attn_metadata
lam.local_max_query_len = int(seqlens_q_local_np.max())
            lam.local_max_seq_len = int(seqlens_k_local_np.max())

def _init_sliding_window_attn_spec_metadata(
self,
metadata: FlashAttentionMetadata,
metadata_expand: FlashAttentionMetadata,
metadata_swa: Optional[FlashAttentionMetadata] = None,
):
# TODO: support page_size > 1 for swa spec
        assert (
            self.page_size == 1
        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with sliding window attention when page_size > 1"
cache_seqlens_int32 = (
metadata.cache_seqlens_int32.repeat_interleave(
self.speculative_num_draft_tokens
)
+ metadata_expand.cache_seqlens_int32
)
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)
bs = cache_seqlens_int32.shape[0]
page_table = (
metadata.page_table.new_zeros(
(bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
)
if metadata_swa is None
else metadata_swa.page_table
)
prepare_swa_spec_page_table_triton(
page_table,
metadata.page_table,
metadata_expand.page_table,
metadata.cache_seqlens_int32,
metadata_expand.cache_seqlens_int32,
self.speculative_num_draft_tokens,
)
if metadata_swa is None:
metadata_swa = FlashAttentionMetadata()
metadata_swa.max_seq_len_q = 1
metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
metadata_swa.cu_seqlens_k = cu_seqlens_k
metadata_swa.page_table = page_table
else:
metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
metadata.swa_spec_metadata = metadata_swa
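

# Kernel sketch for _prepare_swa_spec_page_table_kernel: program axis 0
# enumerates expanded rows (one per (request, draft-token) pair);
# idx_a = pid_m // REPEAT_STEP maps a row back to its source request in src_a,
# while idx_b indexes the expanded src_b directly. Worked example (assumed
# values): with speculative_num_draft_tokens=2, a request with prefix length 5
# and per-row draft lengths [2, 1] yields combined lengths [7, 6]; each
# destination row holds the request's prefix entries followed by that row's
# own draft entries.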
@triton.jit
def _prepare_swa_spec_page_table_kernel(
dst_ptr,
src_a_ptr,
src_b_ptr,
seq_len_a_ptr,
seq_len_b_ptr,
dst_stride_m,
dst_stride_n,
a_stride_m,
a_stride_n,
b_stride_m,
b_stride_n,
LEN_A: tl.constexpr,
LEN_B: tl.constexpr,
REPEAT_STEP: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
idx_a = pid_m // REPEAT_STEP
idx_b = pid_m
seq_len_a = tl.load(seq_len_a_ptr + idx_a)
seq_len_b = tl.load(seq_len_b_ptr + idx_b)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
total_len = seq_len_a + seq_len_b
if pid_n * BLOCK_N >= total_len:
return
mask = offs_n < total_len
dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
if (pid_n + 1) * BLOCK_N < seq_len_a:
a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
a_mask = mask & (offs_n < LEN_A)
val = tl.load(a_ptr, mask=a_mask, other=0)
tl.store(dst, val, mask=mask)
elif pid_n * BLOCK_N >= seq_len_a:
offs_b = offs_n - seq_len_a
b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
b_mask = mask & (offs_b < LEN_B)
val = tl.load(b_ptr, mask=b_mask, other=0)
tl.store(dst, val, mask=mask)
else:
        # Mixed tile: spans the boundary between table A (prefix) and table B (draft)
a_offs = offs_n
a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
a_val = tl.load(a_ptr, mask=a_mask, other=0)
b_offs = offs_n - seq_len_a
b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
b_val = tl.load(b_ptr, mask=b_mask, other=0)
result = tl.where(offs_n < seq_len_a, a_val, b_val)
        tl.store(dst, result, mask=mask)


def prepare_swa_spec_page_table_triton(
page_table_dst: torch.Tensor,
page_table_a: torch.Tensor,
page_table_b: torch.Tensor, # expand page table
seq_len_a: torch.Tensor,
seq_len_b: torch.Tensor, # expand seq lens
speculative_num_draft_tokens: int,
):
    # Concatenate each request's page_table row with its expanded page-table rows along the KV-length dimension
bs = seq_len_a.numel()
bs_expand = seq_len_b.numel()
assert bs_expand == bs * speculative_num_draft_tokens
LEN_A = page_table_a.shape[1]
LEN_B = page_table_b.shape[1]
LEN_OUT = LEN_A + LEN_B
REPEAT_STEP = speculative_num_draft_tokens
BLOCK_N = 256
grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
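    # One Triton program instance per (expanded row, BLOCK_N-wide column tile)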
_prepare_swa_spec_page_table_kernel[grid](
page_table_dst,
page_table_a,
page_table_b,
seq_len_a,
seq_len_b,
page_table_dst.stride(0),
page_table_dst.stride(1),
page_table_a.stride(0),
page_table_a.stride(1),
page_table_b.stride(0),
page_table_b.stride(1),
LEN_A=LEN_A,
LEN_B=LEN_B,
REPEAT_STEP=REPEAT_STEP,
BLOCK_N=BLOCK_N,
num_warps=4,
)
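

# Minimal usage sketch (hypothetical shapes/values, not executed here):
#   bs, draft_num, len_a, len_b = 2, 2, 16, 4
#   dst = torch.zeros(bs * draft_num, len_a + len_b, dtype=torch.int32, device="cuda")
#   prepare_swa_spec_page_table_triton(
#       dst, page_table_a, page_table_b, seq_len_a, seq_len_b, draft_num
#   )
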
class FlashAttentionMultiStepBackend:
def __init__(
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
):
self.model_runner = model_runner
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashAttentionBackend(
model_runner,
speculative_step_id=i,
topk=self.topk,
speculative_num_steps=self.speculative_num_steps,
)
            )

def init_forward_metadata(self, forward_batch: ForwardBatch):
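        # Only the first (speculative_num_steps - 1) backends prepare metadata
        # here; the last draft step is assumed to be handled elsewhere (e.g. by
        # the draft-extend path), matching the loop bounds used below.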
for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

def init_forward_metadata_capture_cuda_graph(
self,
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=forward_batch.encoder_lens,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
            )

def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
for i in range(self.speculative_num_steps - 1):
# TODO: incrementally update the metadata for the later steps,
# so that they do not need to recompute everything from scratch.
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
encoder_lens=forward_batch.encoder_lens,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
out_cache_loc=forward_batch.out_cache_loc,
            )


# @torch.compile(dynamic=True, backend=get_compiler_backend())
# TODO: fuse these kernels
# NOTE: torch.compile makes it slower in speculative decoding
def normal_decode_set_metadata(
cache_seqlens_int32: torch.Tensor,
cu_seqlens_k: torch.Tensor,
page_table: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
strided_indices: torch.Tensor,
    max_seq_pages: int,
seq_lens: torch.Tensor,
seq_len_delta: int,
page_size: int,
):
cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
page_indices = req_to_token[
req_pool_indices[:, None],
strided_indices[:max_seq_pages][None, :],
]
page_table[:, :max_seq_pages].copy_(page_indices // page_size)
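

# Worked example for normal_decode_set_metadata (assumed values): two requests
# with seq_lens [3, 5], seq_len_delta=0 and page_size=4 give
# cache_seqlens = [3, 5] and cu_seqlens_k = [0, 3, 8]; with max_seq_pages = 2
# and strided_indices assumed to hold token offsets [0, page_size, ...], each
# page-table row becomes req_to_token[req, [0, 4]] // 4, i.e. the first two
# page ids.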