from __future__ import annotations

"""
Support attention backend for flashinfer MLA.

The flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is used; it defaults to false.
When it is set to false (the default), the backend uses the BatchRagged wrapper together with the BatchMLAPaged
wrapper for prefilling, and the BatchMLAPaged wrapper for decoding.
When it is set to true, all wrappers are BatchMLAPaged wrappers.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""

import os
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch

# Use .get() so that importing this module does not fail when the variable is unset.
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE") == "1":
    import logging

    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
    from flashinfer import (
        BatchMLAPagedAttentionWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )


@dataclass
class DecodeMetadata:
    decode_wrapper: BatchMLAPagedAttentionWrapper


@dataclass
class PrefillMetadata:
    prefill_wrapper: BatchMLAPagedAttentionWrapper
    use_ragged: bool


# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None


class FlashInferMhaChunkKVRunner:
    def __init__(
        self, model_runner: ModelRunner, attn_backend: "FlashInferMLAAttnBackend"
    ):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.data_type = model_runner.dtype
        self.q_data_type = model_runner.dtype

        # Buffers and wrappers
        self.qo_indptr = attn_backend.qo_indptr
        self.workspace_buffer = attn_backend.workspace_buffer
        self.fmha_backend = attn_backend.fmha_backend

        self.chunk_ragged_wrappers = []
        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged

    def update_prefix_chunks(self, num_prefix_chunks: int):
        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD", backend=self.fmha_backend
            )
            self.chunk_ragged_wrappers.append(ragged_wrapper)

    def update_wrapper(
        self,
        forward_batch: ForwardBatch,
    ):
        assert forward_batch.num_prefix_chunks is not None
        num_prefix_chunks = forward_batch.num_prefix_chunks
        self.update_prefix_chunks(num_prefix_chunks)

        prefix_lens = forward_batch.extend_prefix_lens
        seq_lens = forward_batch.seq_lens

        bs = len(seq_lens)
        qo_indptr = self.qo_indptr
        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            # MHA for chunked prefix kv cache when running model with MLA
            assert forward_batch.prefix_chunk_idx is not None
            assert forward_batch.prefix_chunk_cu_seq_lens is not None
            assert forward_batch.prefix_chunk_max_seq_lens is not None

            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
            wrapper.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=kv_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
                causal=False,
            )
        # ragged prefill
        self.ragged_wrapper.begin_forward(
            qo_indptr=qo_indptr,
            kv_indptr=qo_indptr,
            num_qo_heads=self.num_local_heads,
            num_kv_heads=self.num_local_heads,
            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
            head_dim_vo=self.v_head_dim,
            q_data_type=self.q_data_type,
            causal=True,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
    ):
        logits_soft_cap = layer.logit_cap
        if forward_batch.attn_attend_prefix_cache:
            chunk_idx = forward_batch.prefix_chunk_idx
            assert chunk_idx >= 0
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
            o1, s1 = wrapper.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
                causal=False,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            o1, s1 = self.ragged_wrapper.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )

        return o1, s1
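
    # Usage note (hedged sketch, not executed here): the (output, lse) pairs produced
    # above are combined by the caller across the prefix-cache chunks and the ragged
    # (new-token) pass. With flashinfer this kind of combination is typically done via
    # merge_state, e.g.:
    #
    #   from flashinfer import merge_state
    #   o, s = outputs[0], lses[0]            # hypothetical per-chunk results
    #   for o_i, s_i in zip(outputs[1:], lses[1:]):
    #       o, s = merge_state(o, s, o_i, s_i)
    #
    # The actual merging lives in the MLA model code that calls forward_extend, not in
    # this runner.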


class FlashInferMLAAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.skip_prefill = skip_prefill
        self.enable_chunk_kv = (
            not skip_prefill
            and global_server_args_dict["disaggregation_mode"] != "decode"
            and not global_server_args_dict["disable_chunked_prefix_cache"]
            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
        )
        self.page_size = model_runner.page_size

        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
            # different from flashinfer zero_init_global_workspace_buffer
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_indices = torch.empty(
            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
            dtype=torch.int32,
            device=model_runner.device,
        )

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

        if q_indptr_decode_buf is None:
            # A hack to pre-initialize large batch size for dp attention
            if model_runner.server_args.enable_dp_attention:
                max_bs = model_runner.server_args.dp_size * max_bs
            self.q_indptr_decode = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
            )
        else:
            self.q_indptr_decode = q_indptr_decode_buf

        self.fmha_backend = "auto"
        if is_sm100_supported():
            self.fmha_backend = "cutlass"
        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend=self.fmha_backend
        )

        if not self.skip_prefill:
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

            # FlashinferMLA backend uses mla wrapper for target verify
            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                model_runner, self
            )
            if self.enable_chunk_kv:
                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)

        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrapper=self.decode_wrapper,
                init_metadata_replay=False,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_verify,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            use_ragged = (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and extend_no_prefix
            )

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, use_ragged
            )

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        self.cuda_graph_kv_indices = (
            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
        )
        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
        self.cuda_graph_kv_lens = torch.ones(
            (max_bs,), dtype=torch.int32, device=self.device
        )

        # For fast decode plan in graph replaying
        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
        self.fast_decode_kwargs = {
            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
            "kv_indices": self.cuda_graph_kv_indices,
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
                init_metadata_replay=False,
                spec_info=spec_info,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
        elif forward_mode.is_target_verify():
            verify_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=verify_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
        elif forward_mode.is_draft_extend():
            draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=draft_extend_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper
            self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            kv_len_arr_cpu = seq_lens_cpu[:bs]
            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                num_pages_per_req, dim=0
            )
            self.fast_decode_kwargs.update(
                {
                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
                    "kv_len_arr_cpu": kv_len_arr_cpu,
                }
            )
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
                init_metadata_replay=True,
                spec_info=spec_info,
                **self.fast_decode_kwargs,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        elif forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        self.mha_chunk_kv_cache.update_wrapper(forward_batch)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        if (
            forward_batch.attn_attend_prefix_cache is not None
            and forward_batch.mha_return_lse
        ):  # MHA Chunk
            assert self.enable_chunk_kv
            assert q_rope is None
            assert k_rope is None
            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
            return o1, s1

        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap
        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper

        # Save kv cache
        if save_kv_cache and k is not None:
            assert v is not None
            if k_rope is not None:
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                    layer, cache_loc, k, k_rope
                )
            else:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
        if q_rope is not None:
            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        if self.forward_metadata.use_ragged:
            # ragged prefill
            if q_rope is not None:
                q = torch.cat([q, q_rope], dim=-1)
            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            if k_rope is not None:
                k = torch.cat([k, k_rope], dim=-1)
            o = self.prefill_wrapper_ragged.forward(
                qall,
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            # mla paged prefill
            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                q.dtype
            )
            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

            if q_rope is None:
                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                q, q_rope = (
                    qall[:, :, : layer.v_head_dim],
                    qall[:, :, layer.v_head_dim :],
                )
            o = q.new_empty(q.shape)
            o = prefill_wrapper_paged.run(
                q,
                q_rope,
                k_buf[:, :, : layer.v_head_dim],
                k_buf[:, :, layer.v_head_dim :],
                out=o,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                if k_rope is not None:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Reshape inputs
        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = reshaped_q[:, :, : layer.v_head_dim]
            q_rope = reshaped_q[:, :, layer.v_head_dim :]

        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
            q.dtype
        )
        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

        o = q_nope.new_empty(q_nope.shape)
        o = decode_wrapper.run(
            q_nope,
            q_rope,
            k_buf[:, :, : layer.v_head_dim],
            k_buf[:, :, layer.v_head_dim :],
            out=o,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)


class FlashInferMLAIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.dtype
        self.attn_backend = attn_backend
        self.page_size = model_runner.page_size
        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_indices = attn_backend.kv_indices
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.q_indptr = attn_backend.q_indptr_decode

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        # Fall back to the backend's default decode wrapper if none was passed.
        decode_wrapper = decode_wrapper or self.attn_backend.decode_wrapper
        self.call_begin_forward(
            decode_wrapper,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.q_indptr,
            self.kv_indptr,
            init_metadata_replay,
            spec_info,
            **fast_decode_kwargs,
        )

    def call_begin_forward(
        self,
        wrapper: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        bs = len(req_pool_indices)
        q_indptr = q_indptr[: bs + 1]
        kv_lens = paged_kernel_lens.to(torch.int32)
        sm_scale = self.scaling
        if spec_info is None:
            num_pages_per_req = (
                paged_kernel_lens + self.page_size - 1
            ) // self.page_size
            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = (
                self.kv_indices[: kv_indptr[-1]]
                if not init_metadata_replay
                else fast_decode_kwargs["kv_indices"]
            )

            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
                self.page_size,
            )
        else:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

        if not init_metadata_replay:
            wrapper.plan(
                qo_indptr=q_indptr,
                kv_indptr=kv_indptr,
                kv_indices=kv_indices,
                kv_len_arr=kv_lens,
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=False,
                sm_scale=sm_scale,
                q_data_type=self.data_type,
                kv_data_type=self.data_type,
            )
        else:
            wrapper.plan(
                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
                kv_indices=kv_indices,
                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=False,
                sm_scale=sm_scale,
                q_data_type=self.data_type,
                kv_data_type=self.data_type,
            )


class FlashInferMLAIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.dtype
        self.q_data_type = model_runner.dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.qo_indptr = attn_backend.qo_indptr
        self.kv_indices = attn_backend.kv_indices
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
        self.page_size = model_runner.page_size

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum
        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrapper_paged,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
            spec_info,
        )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        bs = len(seq_lens)
        sm_scale = self.scaling

        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            num_pages_per_req = (
                paged_kernel_lens + self.page_size - 1
            ) // self.page_size
            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = self.kv_indices[: kv_indptr[-1]]
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
                self.page_size,
            )
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, EagleDraftInput) or isinstance(
                spec_info, EagleVerifyInput
            )
            # TODO: Support topk > 1 with custom mask
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )
        if use_ragged:
            # ragged prefill
            wrapper_ragged.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=qo_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
                causal=True,
            )
        else:
            # mla paged prefill
            if spec_info is not None:
                assert (
                    self.page_size == 1
                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
            else:
                kv_lens = paged_kernel_lens.to(torch.int32)
            wrapper_paged.plan(
                qo_indptr=qo_indptr,
                kv_indptr=kv_indptr,
                kv_indices=kv_indices,
                kv_len_arr=kv_lens,
                num_heads=self.num_local_heads,
                head_dim_ckv=self.kv_lora_rank,
                head_dim_kpe=self.qk_rope_head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=sm_scale,
                q_data_type=self.q_data_type,
                kv_data_type=self.data_type,
            )


class FlashInferMLAMultiStepDraftBackend:
    """
    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        if topk > 1:
            raise ValueError(
                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )

        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashInferMLAAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    q_indptr_decode_buf=self.q_indptr_decode,
                )
            )

        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device="cuda",
        )

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            assert isinstance(forward_batch.spec_info, EagleDraftInput)
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device="cuda",
        )

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,
    kv_indptr_cpu: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_len_arr_cpu: torch.Tensor,
    num_heads: int,
    head_dim_ckv: int,
    head_dim_kpe: int,
    page_size: int,
    causal: bool,
    sm_scale: float,
    q_data_type: torch.dtype,
    kv_data_type: torch.dtype,
) -> None:
    """A faster version of BatchMLAPagedAttentionWrapper::plan that skips the
    stream synchronization of the original plan function during cuda graph
    replaying.
    """
    self._causal = causal
    self._page_size = page_size
    self._sm_scale = sm_scale

    try:
        # Standard version with just the required arguments (no use_profiler)
        self._cached_module.plan.default(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_cpu,
            kv_indptr_cpu,
            kv_len_arr_cpu,
            num_heads,
            head_dim_ckv,
            causal,
        )
    except Exception as e:
        raise RuntimeError(f"Error in alternate MLA plan: {e}") from e
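
# Usage note: during CUDA graph capture the backend rebinds the decode wrapper's
# plan method to this function (see init_forward_metadata_capture_cuda_graph:
# `decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)`), so graph
# replay can re-plan from pinned CPU metadata without a device synchronization.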