# Source: sglang/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
# (1091 lines, 40 KiB, Python)
from __future__ import annotations
"""
Support attention backend for flashinfer MLA.
The flashinfer_mla_disable_ragged flag controls whether to use ragged prefill wrapper and defaults to be false.
When it's set to false, all wrappers are BatchMLAPaged wrapper.
When it's set to true, the backend uses BatchRagged and BatchMLAPaged wrapper for prefilling,
and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
import os
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import logging
torch._logging.set_logs(dynamo=logging.ERROR)
torch._dynamo.config.suppress_errors = True
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
@dataclass
class DecodeMetadata:
    """Per-forward metadata for decode batches."""

    # The planned MLA paged wrapper to run for this decode batch.
    decode_wrapper: BatchMLAPagedAttentionWrapper
@dataclass
class PrefillMetadata:
    """Per-forward metadata for prefill / extend / verify batches."""

    # The planned MLA paged wrapper to run for this prefill batch.
    prefill_wrapper: BatchMLAPagedAttentionWrapper
    # Whether the ragged (non-paged) prefill path is used for the new tokens.
    use_ragged: bool
# Reuse this workspace buffer across all flashinfer wrappers
# (lazily allocated on first backend construction; see FlashInferMLAAttnBackend.__init__).
global_workspace_buffer: Optional[torch.Tensor] = None
class FlashInferMhaChunkKVRunner:
    """Runs ragged MHA attention over chunked prefix KV cache for MLA models.

    Each prefix chunk is attended non-causally with its own ragged wrapper,
    while the new (extend) tokens attend to themselves causally; both paths
    return LSE values so the partial outputs can be merged by the caller.
    """

    def __init__(
        self, model_runner: ModelRunner, attn_backend: "FlashInferMLAAttnBackend"
    ):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.data_type = model_runner.dtype
        self.q_data_type = model_runner.dtype
        # Buffers and wrappers (shared with the parent attention backend)
        self.qo_indptr = attn_backend.qo_indptr
        self.workspace_buffer = attn_backend.workspace_buffer
        self.fmha_backend = attn_backend.fmha_backend
        # One ragged wrapper per prefix chunk, grown lazily in update_prefix_chunks.
        self.chunk_ragged_wrappers = []
        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged

    def update_prefix_chunks(self, num_prefix_chunks: int):
        # Lazily grow the per-chunk wrapper pool to num_prefix_chunks entries;
        # existing wrappers are reused across batches.
        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD", backend=self.fmha_backend
            )
            self.chunk_ragged_wrappers.append(ragged_wrapper)

    def update_wrapper(
        self,
        forward_batch: ForwardBatch,
    ):
        """Plan the per-chunk (non-causal) and main (causal) ragged wrappers."""
        assert forward_batch.num_prefix_chunks is not None
        num_prefix_chunks = forward_batch.num_prefix_chunks
        self.update_prefix_chunks(num_prefix_chunks)

        prefix_lens = forward_batch.extend_prefix_lens
        seq_lens = forward_batch.seq_lens
        bs = len(seq_lens)
        qo_indptr = self.qo_indptr
        # Query lengths are the extend (non-prefix) portions of each request.
        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            # MHA for chunked prefix kv cache when running model with MLA
            assert forward_batch.prefix_chunk_idx is not None
            assert forward_batch.prefix_chunk_cu_seq_lens is not None
            assert forward_batch.prefix_chunk_max_seq_lens is not None
            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
            # Prefix chunks are attended non-causally: every query token may
            # see the whole chunk.
            wrapper.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=kv_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
                causal=False,
            )
        # ragged prefill over the new tokens themselves (causal; note kv_indptr
        # equals qo_indptr because keys are the same new tokens)
        self.ragged_wrapper.begin_forward(
            qo_indptr=qo_indptr,
            kv_indptr=qo_indptr,
            num_qo_heads=self.num_local_heads,
            num_kv_heads=self.num_local_heads,
            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
            head_dim_vo=self.v_head_dim,
            q_data_type=self.q_data_type,
            causal=True,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
    ):
        """Run ragged attention and return (output, lse) for later merging."""
        logits_soft_cap = layer.logit_cap
        if forward_batch.attn_attend_prefix_cache:
            # Attend to one chunk of the cached prefix (non-causal).
            chunk_idx = forward_batch.prefix_chunk_idx
            assert chunk_idx >= 0
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
            o1, s1 = wrapper.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
                causal=False,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            # Attend to the new tokens themselves (causal).
            o1, s1 = self.ragged_wrapper.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        return o1, s1
class FlashInferMLAAttnBackend(AttentionBackend):
    """Flashinfer attention kernels for MLA (multi-head latent attention).

    Decode always goes through a BatchMLAPaged wrapper; prefill uses either
    the ragged wrapper (when there is no prefix and ragged is not disabled)
    or the paged MLA wrapper. Target-verify and draft-extend speculative
    modes reuse the paged wrapper with per-mode planned metadata.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.skip_prefill = skip_prefill
        # Chunked prefix-cache MHA is only usable when prefill runs in this
        # backend and ragged prefill is enabled.
        self.enable_chunk_kv = (
            not skip_prefill
            and global_server_args_dict["disaggregation_mode"] != "decode"
            and not global_server_args_dict["disable_chunked_prefix_cache"]
            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
        )
        self.page_size = model_runner.page_size

        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
            # different from flashinfer zero_init_global_workspace_buffer
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        # Enough page slots for max_bs requests at max context length.
        self.kv_indices = torch.empty(
            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
            dtype=torch.int32,
            device=model_runner.device,
        )

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

        if q_indptr_decode_buf is None:
            # A hack to pre-initialize large batch size for dp attention
            if model_runner.server_args.enable_dp_attention:
                max_bs = model_runner.server_args.dp_size * max_bs
            self.q_indptr_decode = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
            )
        else:
            self.q_indptr_decode = q_indptr_decode_buf

        self.fmha_backend = "auto"
        if is_sm100_supported():
            self.fmha_backend = "cutlass"
        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend=self.fmha_backend
        )

        if not self.skip_prefill:
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )
            # FlashinferMLA backend uses mla wrapper for target verify
            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                model_runner, self
            )
            if self.enable_chunk_kv:
                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)

        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Plan wrappers for a (non-CUDA-graph) forward pass based on its mode."""
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrapper=self.decode_wrapper,
                init_metadata_replay=False,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_verify,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            # Normal extend/prefill: ragged is only usable when no request has
            # a cached prefix and the ragged path is not disabled.
            prefix_lens = forward_batch.extend_prefix_lens
            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            use_ragged = (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and extend_no_prefix
            )
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, use_ragged
            )

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate the fixed-address buffers that CUDA graphs capture."""
        self.cuda_graph_kv_indices = (
            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
        )
        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
        self.cuda_graph_kv_lens = torch.ones(
            (max_bs,), dtype=torch.int32, device=self.device
        )

        # For fast decode plan in graph replaying
        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
        self.fast_decode_kwargs = {
            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
            "kv_indices": self.cuda_graph_kv_indices,
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Create and plan a per-batch-size wrapper for CUDA graph capture."""
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
                init_metadata_replay=False,
                spec_info=spec_info,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
            # Replace plan() with the sync-free fast path used during replay.
            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
        elif forward_mode.is_target_verify():
            verify_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=verify_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
        elif forward_mode.is_draft_extend():
            draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=draft_extend_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper
            self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the captured wrapper's metadata before replaying a graph."""
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            kv_len_arr_cpu = seq_lens_cpu[:bs]
            # Recompute CPU-side page indptr so the fast plan path can skip
            # device synchronization.
            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                num_pages_per_req, dim=0
            )
            self.fast_decode_kwargs.update(
                {
                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
                    "kv_len_arr_cpu": kv_len_arr_cpu,
                }
            )
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
                init_metadata_replay=True,
                spec_info=spec_info,
                **self.fast_decode_kwargs,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        elif forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
        """Fill value for padded seq_lens entries during CUDA graph replay."""
        return 1

    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        self.mha_chunk_kv_cache.update_wrapper(forward_batch)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        """Run prefill/extend attention; returns (out, lse) on the MHA chunk
        path, otherwise the flattened attention output."""
        if (
            forward_batch.attn_attend_prefix_cache is not None
            and forward_batch.mha_return_lse
        ):  # MHA Chunk
            assert self.enable_chunk_kv
            assert q_rope is None
            assert k_rope is None
            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
            return o1, s1

        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap
        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper

        # Save kv cache. (The original had a redundant nested
        # `if save_kv_cache:` inside this branch; it is removed.)
        if save_kv_cache and k is not None:
            assert v is not None
            if k_rope is not None:
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                    layer, cache_loc, k, k_rope
                )
            else:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        # When rope comes in separately, q holds only the nope part.
        if q_rope is not None:
            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )

        if self.forward_metadata.use_ragged:
            # ragged prefill
            if q_rope is not None:
                q = torch.cat([q, q_rope], dim=-1)
            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            if k_rope is not None:
                k = torch.cat([k, k_rope], dim=-1)
            o = self.prefill_wrapper_ragged.forward(
                qall,
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            # mla paged prefill
            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                q.dtype
            )
            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
            if q_rope is None:
                # Split a fused q into nope / rope parts along the head dim.
                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                q, q_rope = (
                    qall[:, :, : layer.v_head_dim],
                    qall[:, :, layer.v_head_dim :],
                )
            o = q.new_empty(q.shape)
            o = prefill_wrapper_paged.run(
                q,
                q_rope,
                k_buf[:, :, : layer.v_head_dim],
                k_buf[:, :, layer.v_head_dim :],
                out=o,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        """Run decode attention with the planned MLA paged wrapper."""
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                if k_rope is not None:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Reshape inputs: split (or accept pre-split) nope / rope parts.
        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = reshaped_q[:, :, : layer.v_head_dim]
            q_rope = reshaped_q[:, :, layer.v_head_dim :]

        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
            q.dtype
        )
        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

        o = q_nope.new_empty(q_nope.shape)
        o = decode_wrapper.run(
            q_nope,
            q_rope,
            k_buf[:, :, : layer.v_head_dim],
            k_buf[:, :, layer.v_head_dim :],
            out=o,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class FlashInferMLAIndicesUpdaterDecode:
    """Builds paged KV indices for decode batches and plans the MLA wrapper."""

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.dtype
        self.attn_backend = attn_backend
        self.page_size = model_runner.page_size

        # Buffers and wrappers (shared with the parent attention backend)
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_indices = attn_backend.kv_indices
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.q_indptr = attn_backend.q_indptr_decode

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        """Plan `decode_wrapper` for one decode batch.

        With init_metadata_replay=True (CUDA graph replay), the CPU-side
        indptr/len buffers in fast_decode_kwargs are used through the fast
        plan path instead of being re-derived on device.
        """
        # Fall back to the backend's shared decode wrapper when none is given.
        # (The original read `self.decode_wrapper`, an attribute this class
        # never sets, which would raise AttributeError if the fallback ran.)
        decode_wrapper = decode_wrapper or self.attn_backend.decode_wrapper

        self.call_begin_forward(
            decode_wrapper,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.q_indptr,
            self.kv_indptr,
            init_metadata_replay,
            spec_info,
            **fast_decode_kwargs,
        )

    def call_begin_forward(
        self,
        wrapper: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        """Fill kv_indptr/kv_indices (unless supplied by spec_info) and plan."""
        bs = len(req_pool_indices)
        q_indptr = q_indptr[: bs + 1]
        kv_lens = paged_kernel_lens.to(torch.int32)
        sm_scale = self.scaling

        if spec_info is None:
            # Normal decode: derive per-request page counts and gather page
            # indices from the req_to_token pool.
            num_pages_per_req = (
                paged_kernel_lens + self.page_size - 1
            ) // self.page_size
            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = (
                self.kv_indices[: kv_indptr[-1]]
                if not init_metadata_replay
                else fast_decode_kwargs["kv_indices"]
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
                self.page_size,
            )
        else:
            # Speculative decoding: indices were precomputed by the spec algo.
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

        if not init_metadata_replay:
            wrapper.plan(
                qo_indptr=q_indptr,
                kv_indptr=kv_indptr,
                kv_indices=kv_indices,
                kv_len_arr=kv_lens,
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=False,
                sm_scale=sm_scale,
                q_data_type=self.data_type,
                kv_data_type=self.data_type,
            )
        else:
            # Fast path patched in during CUDA graph capture: plan from CPU
            # buffers to avoid a device synchronization.
            wrapper.plan(
                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
                kv_indices=kv_indices,
                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=False,
                sm_scale=sm_scale,
                q_data_type=self.data_type,
                kv_data_type=self.data_type,
            )
class FlashInferMLAIndicesUpdaterPrefill:
    """Builds prefill-time indices and plans the ragged or paged MLA wrapper."""

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.dtype
        self.q_data_type = model_runner.dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers (shared with the parent attention backend)
        self.kv_indptr = attn_backend.kv_indptr
        self.qo_indptr = attn_backend.qo_indptr
        self.kv_indices = attn_backend.kv_indices
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
        self.page_size = model_runner.page_size

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        """Plan wrappers for one prefill batch.

        With use_ragged the paged kernel lengths are the (cached-prefix)
        lengths; otherwise they span the full sequences.
        """
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrapper_paged,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
            spec_info,
        )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        """Build kv/qo indptr+indices, then plan either the ragged wrapper
        (use_ragged=True) or the paged MLA wrapper (use_ragged=False)."""
        bs = len(seq_lens)
        sm_scale = self.scaling

        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            # Normal prefill: derive per-request page counts and gather page
            # indices from the req_to_token pool.
            num_pages_per_req = (
                paged_kernel_lens + self.page_size - 1
            ) // self.page_size
            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = self.kv_indices[: kv_indptr[-1]]
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
                self.page_size,
            )
            # Query lengths are the extend (non-prefix) portions.
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, EagleDraftInput) or isinstance(
                spec_info, EagleVerifyInput
            )
            # TODO: Support topk > 1 with custom mask
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        if use_ragged:
            # ragged prefill: keys are the same new tokens, hence
            # kv_indptr == qo_indptr
            wrapper_ragged.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=qo_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
                causal=True,
            )
        else:
            # mla paged prefill
            if spec_info is not None:
                assert (
                    self.page_size == 1
                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
            else:
                kv_lens = paged_kernel_lens.to(torch.int32)
            wrapper_paged.plan(
                qo_indptr=qo_indptr,
                kv_indptr=kv_indptr,
                kv_indices=kv_indices,
                kv_len_arr=kv_lens,
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=sm_scale,
                q_data_type=self.q_data_type,
                kv_data_type=self.data_type,
            )
class FlashInferMLAMultiStepDraftBackend:
    """
    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        if topk > 1:
            raise ValueError(
                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        # One kv_indptr row per draft step; each step's backend gets its row.
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )

        # One decode-only backend per draft step, sharing the indptr buffers.
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashInferMLAAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    q_indptr_decode_buf=self.q_indptr_decode,
                )
            )

        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        """Fill per-step kv indices with the triton kernel, then invoke
        call_fn(step, forward_batch) for each draft step."""
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        # Grid: (steps, sequences, topk) — one program per (step, seq, branch).
        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        # NOTE: iterates speculative_num_steps - 1 times; presumably the last
        # step needs no further draft metadata — confirm against caller.
        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            # Each later step sees bs more KV entries per step drafted so far.
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Plan all per-step backends for a non-CUDA-graph forward pass."""
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device="cuda",
        )

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            assert isinstance(forward_batch.spec_info, EagleDraftInput)
            # Clone so each step's backend keeps its own snapshot.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Allocate per-step kv index buffers and init each backend's state."""
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device="cuda",
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Capture-time planning for every draft step (decode mode)."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Replay-time re-planning for every draft step (decode mode)."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,
    kv_indptr_cpu: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_len_arr_cpu: torch.Tensor,
    num_heads: int,
    head_dim_ckv: int,
    head_dim_kpe: int,
    page_size: int,
    causal: bool,
    sm_scale: float,
    q_data_type: torch.dtype,
    kv_data_type: torch.dtype,
) -> None:
    """A faster version of BatchMLAPagedAttentionWrapper::plan,
    for skipping the stream synchronization in original plan function during
    cuda graph replaying.

    Bound onto a wrapper instance via functools.partial, so `self` is the
    wrapper. It caches the scalar attributes and re-runs the cached plan
    kernel with CPU-side indptr/len buffers; the remaining arguments are
    accepted for signature compatibility but not forwarded.
    """
    self._causal = causal
    self._page_size = page_size
    self._sm_scale = sm_scale

    try:
        # Standard version with just the required arguments (no use_profiler)
        self._cached_module.plan.default(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_cpu,
            kv_indptr_cpu,
            kv_len_arr_cpu,
            num_heads,
            head_dim_ckv,
            causal,
        )
    except Exception as e:
        # Chain the original exception so the root cause is not lost.
        raise RuntimeError(f"Error in alternate MLA plan: {e}") from e