Add Support for Page Size greater than 1 for Flashinfer MLA Backend (#8593)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-08-21 18:15:06 -07:00
committed by GitHub
parent 0b3a5b1151
commit 3cc3d9b950
5 changed files with 292 additions and 105 deletions

View File

@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -72,11 +70,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.skip_prefill = skip_prefill
self.page_size = model_runner.page_size
# Allocate buffers
global global_workspace_buffer
@@ -97,15 +95,25 @@ class FlashInferMLAAttnBackend(AttentionBackend):
else:
self.kv_indptr = kv_indptr_buf
self.kv_indices = torch.empty(
(max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
dtype=torch.int32,
device=model_runner.device,
)
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
if q_indptr_decode_buf is None:
# A hack to pre-initialize large batch size for dp attention
if model_runner.server_args.enable_dp_attention:
max_bs = model_runner.server_args.dp_size * max_bs
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
else:
self.q_indptr_decode = q_indptr_decode_buf
@@ -148,6 +156,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.prefill_cuda_graph_metadata = {} # For verify
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
@@ -205,16 +214,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_kv_indices = (
self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
)
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
self.cuda_graph_kv_lens = torch.ones(
@@ -240,6 +242,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
@@ -250,7 +253,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
backend="auto",
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
@@ -321,11 +323,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
kv_len_arr_cpu = seq_lens_cpu[:bs]
num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
num_pages_per_req, dim=0
)
self.fast_decode_kwargs.update(
{
@@ -334,7 +338,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
"kv_len_arr_cpu": kv_len_arr_cpu,
}
)
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
@@ -381,7 +384,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
@@ -401,7 +403,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
if self.forward_metadata.use_ragged:
# ragged prefill
if q_rope is not None:
@@ -422,6 +423,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
if q_rope is None:
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q, q_rope = (
@@ -483,17 +486,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_nope = reshaped_q[:, :, : layer.v_head_dim]
q_rope = reshaped_q[:, :, layer.v_head_dim :]
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
o = q_nope.new_empty(q_nope.shape)
# Direct call to run without the wrapper
o = decode_wrapper.run(
q_nope,
q_rope,
k_buffer[:, :, : layer.v_head_dim],
k_buffer[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
out=o,
)
@@ -512,9 +515,10 @@ class FlashInferMLAIndicesUpdaterDecode:
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.dtype
self.attn_backend = attn_backend
self.page_size = model_runner.page_size
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_indices = attn_backend.kv_indices
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode
@@ -558,13 +562,17 @@ class FlashInferMLAIndicesUpdaterDecode:
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
num_pages_per_req = (
paged_kernel_lens + self.page_size - 1
) // self.page_size
kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
self.kv_indices[: kv_indptr[-1]]
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
@@ -573,39 +581,40 @@ class FlashInferMLAIndicesUpdaterDecode:
None,
kv_indices,
self.req_to_token.shape[1],
self.page_size,
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
if not init_metadata_replay:
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
qo_indptr=q_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
kv_len_arr=kv_lens,
num_heads=self.num_local_heads,
head_dim_ckv=self.kv_lora_rank,
head_dim_kpe=self.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=sm_scale,
q_data_type=self.data_type,
kv_data_type=self.data_type,
)
else:
wrapper.plan(
fast_decode_kwargs["qo_indptr_cpu"],
fast_decode_kwargs["kv_indptr_cpu"],
kv_indices,
fast_decode_kwargs["kv_len_arr_cpu"],
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
kv_indices=kv_indices,
kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
num_heads=self.num_local_heads,
head_dim_ckv=self.kv_lora_rank,
head_dim_kpe=self.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=sm_scale,
q_data_type=self.data_type,
kv_data_type=self.data_type,
)
@@ -627,12 +636,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.qo_indptr = attn_backend.qo_indptr
self.kv_indices = attn_backend.kv_indices
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
self.page_size = model_runner.page_size
def update(
self,
req_pool_indices: torch.Tnesor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
@@ -646,7 +657,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
else:
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
self.call_begin_forward(
self.prefill_wrapper_ragged,
prefill_wrapper_paged,
@@ -680,13 +690,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
if spec_info is None:
assert len(seq_lens) == len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
num_pages_per_req = (
paged_kernel_lens + self.page_size - 1
) // self.page_size
kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
kv_indices = self.kv_indices[: kv_indptr[-1]]
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
@@ -695,6 +704,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
None,
kv_indices,
self.req_to_token.shape[1],
self.page_size,
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
@@ -712,7 +722,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.req_to_token,
)
)
if use_ragged:
# ragged prefill
wrapper_ragged.begin_forward(
@@ -726,20 +735,26 @@ class FlashInferMLAIndicesUpdaterPrefill:
)
else:
# mla paged prefill
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
if spec_info is not None:
assert (
self.page_size == 1
), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
kv_lens = kv_indptr[1:] - kv_indptr[:-1]
else:
kv_lens = paged_kernel_lens.to(torch.int32)
wrapper_paged.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
True,
sm_scale,
self.q_data_type,
self.data_type,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
kv_len_arr=kv_lens,
num_heads=self.num_local_heads,
head_dim_ckv=self.kv_lora_rank,
head_dim_kpe=self.qk_rope_head_dim,
page_size=self.page_size,
causal=True,
sm_scale=sm_scale,
q_data_type=self.q_data_type,
kv_data_type=self.data_type,
)
@@ -834,6 +849,7 @@ class FlashInferMLAMultiStepDraftBackend:
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
self.speculative_num_steps,
@@ -869,6 +885,7 @@ class FlashInferMLAMultiStepDraftBackend:
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
@@ -885,6 +902,7 @@ class FlashInferMLAMultiStepDraftBackend:
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,

View File

@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
PAGE_SIZE: tl.constexpr = 1,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
"""
Create KV indices for FlashInfer attention backend.
# find the req pool idx, this is for batch to token
This Triton kernel builds a lookup table that maps from logical request/token
coordinates to physical token locations in the global KV cache pool. It's used
by FlashInfer attention backends to efficiently access scattered KV cache data.
The kernel processes each request in parallel and converts the req_to_token
lookup table into a flat list of token indices that can be used by attention kernels.
general idea:
blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
fixed number of pages)]
max_pages = max_context_len / PAGE_SIZE
kv_indices_ptr will store the flat list of the pages used by each request
Args:
Inputs Arguments (non mutable):
req_to_token_ptr: Request to token location look up table
Shape: [max_batch, max_context_len]
req_pool_indices_ptr: Request to pool index look up table. Each request uses
one pool.
Shape: [batch_size]
page_kernel_lens_ptr: sequence lengths per request
Shape: [batch_size]
kv_indptr: Start offsets into kv_indices_ptr, computed from the number of
pages used by each request. Used by flashinfer attention kernels to index
into kv_indices_ptr per request.
Shape: [batch_size + 1]
kv_indptr[i] = start index in kv_indices for request i
kv_start_idx: Pointer to array containing start offsets for each request in SGL.
Can be None. If provided, adds offset to token positions.
req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
Equal to max_context_len.
PAGE_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
Outputs:
kv_indices_ptr: Pointer to output array where KV indices will be stored.
Shape:[total-num-pages],
where total_num_pages = sum(ceil(seq_len / PAGE_SIZE) for each request)
Example:
If we have:
- req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
- page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
- req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
The kernel will output:
If PAGE_SIZE = 1:
packed
- kv_indptr (passed in as input arg): [0,3,5]
- kv_indices = [10, 11, 12, 20, 21]
padded - max_pages is 10 tokens per req
- kv_indptr (passed in as input arg): [0,10, 20]
- kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
If PAGE_SIZE = 2
packed:
- kv_indptr (passed in as input arg): [0,2,3]
- kv_indices = [5,6,10]
padded: max_pages is 4
- kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
- kv_indices = [5, 6, -1, -1,
10, -1, -1, -1]
This allows attention kernels to directly access the correct KV cache
entries for each request's tokens.
"""
BLOCK_SIZE: tl.constexpr = 512
NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
# index into req_to_token_ptr needs to be int64
offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
kv_range = kv_end - kv_start
num_pages = tl.cdiv(kv_range, PAGE_SIZE)
num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
req_to_token_block_start = (
req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
)
for i in range(num_loops):
token_offsets_in_block = (
tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
) * PAGE_SIZE
page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
valid_tokens = token_offsets_in_block < kv_range
valid_pages = page_offsets_in_block < num_pages
token_numbers = tl.load(
req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
)
tl.store(
kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
token_numbers // PAGE_SIZE, # write the page numbers to kv_indices_ptr
mask=valid_pages,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
@triton.jit

View File

@@ -636,6 +636,10 @@ class ServerArgs:
logger.warning(
"DeepSeek MTP does not require setting speculative_draft_model_path."
)
if self.page_size != 1 and self.attention_backend == "flashinfer":
raise ValueError(
"Speculative decoding with page_size != 1 is not supported. Please set page_size to 1."
)
# Auto choose parameters
if self.speculative_num_steps is None: