# sglang/python/sglang/srt/layers/attention/nsa_backend.py

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias

import torch

from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.transform_index import (
transform_index_page_table_decode,
transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
NSA_FUSE_TOPK,
compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput

_is_hip = is_hip()

if _is_hip:
try:
from aiter import (
flash_attn_varlen_func,
mha_batch_prefill_func,
paged_attention_ragged,
)
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
    except ImportError:
        print(
            "aiter is an AMD-specific kernel library. Please make sure it is installed on your AMD device."
        )
else:
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass(frozen=True)
class NSAFlashMLAMetadata:
"""Metadata only needed by FlashMLA"""
flashmla_metadata: torch.Tensor
num_splits: torch.Tensor

    def slice(self, sli):
return NSAFlashMLAMetadata(
flashmla_metadata=self.flashmla_metadata,
num_splits=self.num_splits[sli],
)

    def copy_(self, other: "NSAFlashMLAMetadata"):
self.flashmla_metadata.copy_(other.flashmla_metadata)
self.num_splits.copy_(other.num_splits)


@dataclass(frozen=True)
class NSAMetadata:
page_size: int
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor
# Maximum sequence length for query
max_seq_len_q: int
# Maximum sequence length for key
max_seq_len_k: int
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor
# Page table, the index of KV Cache Tables/Blocks
# this table is always with page_size = 1
page_table_1: torch.Tensor
    # NOTE(dark): this property will be used in:
    # 1. dense decode/prefill: paged flash attention needs the real_page_table
    # 2. sparse decode/prefill: the indexer needs the real_page_table to compute the score
real_page_table: torch.Tensor
    # NSA metadata (for NSA prefill, sequences are expanded to one entry per query token)
nsa_cache_seqlens_int32: torch.Tensor # this seqlens is clipped to `topk`
nsa_cu_seqlens_q: torch.Tensor # must be arange(0, len(nsa_cu_seqlens_k))
nsa_cu_seqlens_k: torch.Tensor # cumsum of `nsa_cache_seqlens_int32`
nsa_extend_seq_lens_list: List[int]
nsa_seqlens_expanded: torch.Tensor # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1: each expanded query sequence has length 1 (nsa_cu_seqlens_q is an arange)
flashmla_metadata: Optional[NSAFlashMLAMetadata] = None


@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
attn_metadata: NSAMetadata

    def get_seqlens_int32(self) -> torch.Tensor:
return self.attn_metadata.cache_seqlens_int32

    def get_page_table_64(self) -> torch.Tensor:
return self.attn_metadata.real_page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
return self.attn_metadata.nsa_seqlens_expanded

    def topk_transform(
self,
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
if not NSA_FUSE_TOPK:
return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
# NOTE(dark): if fused, we return a transformed page table directly
return fast_topk_transform_fused(
score=logits,
lengths=self.get_seqlens_expanded(),
page_table_size_1=self.attn_metadata.page_table_1,
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
topk=topk,
)
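

# Example (illustrative): seqlens = [3, 5, 2] -> cu_seqlens = [0, 3, 8, 10];
# the (1, 0) pad prepends the leading zero expected by varlen attention kernels.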
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
assert seqlens.dtype == torch.int32 and seqlens.is_cuda
return torch.nn.functional.pad(
torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)


_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]
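
# NOTE: these module-level impl selectors are assigned from
# `server_args.nsa_prefill` / `server_args.nsa_decode` in
# NativeSparseAttnBackend.__init__.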
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T


class NativeSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata: NSAMetadata
self.device = model_runner.device
assert isinstance(model_runner.page_size, int)
self.real_page_size = model_runner.page_size
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
self.nsa_kv_cache_store_fp8 = (
model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
)
self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
self.max_context_len = model_runner.model_config.context_len
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
assert model_runner.req_to_token_pool is not None
self.req_to_token = model_runner.req_to_token_pool.req_to_token
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
if _is_hip:
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
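
    # The cached arange buffer grows to the next power of two on demand, e.g. a
    # request for length = 20000 reallocates it to 32768 entries and returns the
    # first 20000 (illustrative values).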
    def get_device_int32_arange(self, length: int) -> torch.Tensor:
        if length > len(self._arange_buf):
            next_pow_of_2 = 1 << (length - 1).bit_length()
            self._arange_buf = torch.arange(
                next_pow_of_2, device=self.device, dtype=torch.int32
            )
        return self._arange_buf[:length]
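
    # Convert a page_size=1 table of per-token cache indices into a table of
    # real page indices: with real_page_size = 64 (illustrative), keep columns
    # 0, 64, 128, ... and integer-divide the entries by 64.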
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
page_size = self.real_page_size
if page_size == 1:
return page_table
max_seqlen_k = page_table.shape[1]
strided_indices = torch.arange(
0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
)
return page_table[:, strided_indices] // page_size

    def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
batch_size = forward_batch.batch_size
device = forward_batch.seq_lens.device
        assert (
            forward_batch.spec_info is None
        ), "Speculative decoding is not yet supported by the NSA backend"
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
assert forward_batch.seq_lens_cpu is not None
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
if forward_batch.forward_mode.is_decode_or_idle():
extend_seq_lens_cpu = [1] * batch_size
max_seqlen_q = 1
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
seqlens_expanded = cache_seqlens_int32
elif forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.extend_seq_lens is not None
and forward_batch.extend_prefix_lens_cpu is not None
), "All of them must not be None"
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
if any(forward_batch.extend_prefix_lens_cpu):
max_seqlen_q = max(extend_seq_lens_cpu)
cu_seqlens_q = compute_cu_seqlens(
forward_batch.extend_seq_lens.to(torch.int32)
)
else:
max_seqlen_q = max_seqlen_k
cu_seqlens_q = cu_seqlens_k
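            # Expand per-request KV lengths into one causal KV length per new
            # query token. Example (illustrative): qo_len = 3, kv_len = 10
            # expands to [8, 9, 10].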
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=device,
)
for qo_len, kv_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.seq_lens_cpu.tolist(),
strict=True,
)
]
)
else:
assert False, f"Unsupported {forward_batch.forward_mode = }"
        # 1D expanded seqlens (1D tensors are cheap to compute, so we always compute them)
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
original_seq_lens=seqlens_expanded,
nsa_index_topk=self.nsa_index_topk,
)
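        # compute_nsa_seqlens clips each expanded seqlen to nsa_index_topk, e.g.
        # with topk = 2048 (illustrative), 5000 -> 2048 while 100 stays 100.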
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table,
flashmla_metadata=(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=seqlens_expanded,
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
real_page_table=self._transform_table_1_to_real(page_table),
)
self.forward_metadata = metadata

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs.

        This creates fixed-size tensors that will be reused during CUDA graph
        replay to avoid memory allocations.
        """
self.decode_cuda_graph_metadata: Dict = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# fake page_table for sparse_prefill
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
"flashmla_metadata": (
self._compute_flashmla_metadata(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=self.device
),
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
}

    def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
        assert forward_mode.is_decode_or_idle(), "Only decode is supported for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not yet supported by the NSA backend"
        # Normal decode: get sequence information
        cache_seqlens_int32 = seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
        # Use the full page-table width (max context length) for seq_len_k, so the
        # captured graph is valid for any sequence length at replay
        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        max_seq_len_k = page_table_1.shape[1]
        # Precompute cumulative query sequence lengths
        # NOTE(dark): this is always arange, since we are decoding
        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
real_page_table = self._transform_table_1_to_real(page_table_1)
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs + 1))
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
else:
flashmla_metadata = None
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=1,
max_seq_len_k=max_seq_len_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table_1,
flashmla_metadata=flashmla_metadata,
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=cache_seqlens_int32,
real_page_table=real_page_table,
nsa_extend_seq_lens_list=[1] * bs,
)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
"""Initialize forward metadata for replaying CUDA graph."""
assert seq_lens_cpu is not None
        assert forward_mode.is_decode_or_idle(), "Only decode is supported for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not yet supported by the NSA backend"
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
# Normal Decode
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
max_len = int(seq_lens_cpu.max().item())
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_len]
metadata.page_table_1[:, :max_len].copy_(page_indices)
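        # NOTE: columns beyond max_len may keep stale entries in this fixed
        # buffer; reads are bounded by the (nsa_) cache seqlens, so this is safe.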
assert (
metadata.nsa_cache_seqlens_int32 is not None
and metadata.nsa_cu_seqlens_k is not None
and self.nsa_index_topk is not None
)
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
metadata.nsa_cu_seqlens_k[1:].copy_(
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
)
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
assert self.real_page_size == metadata.page_size
if self.real_page_size > 1:
real_table = self._transform_table_1_to_real(page_indices)
new_len = real_table.shape[1]
metadata.real_page_table[:, :new_len].copy_(real_table)
else:
assert metadata.real_page_table is metadata.page_table_1
if NSA_DECODE_IMPL == "flashmla_decode":
metadata.flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
self.forward_metadata = metadata

    def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert (
not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
), "NSA backend doesn't support speculative decoding"
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# Do absorbed multi-latent attention
assert q_rope is not None
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        # when the cache is both stored and computed in fp8, no dtype conversion is needed
if not (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
):
kv_cache = kv_cache.to(q.dtype)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
# NOTE(dark): here, we use page size = 1
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
assert metadata.nsa_extend_seq_lens_list is not None
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
if NSA_PREFILL_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
metadata=metadata,
page_table_1=page_table_1,
)
elif NSA_PREFILL_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
else:
raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")

    def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
page_size=1,
)
if NSA_DECODE_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
metadata=metadata,
page_table_1=page_table_1,
)
elif NSA_DECODE_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
elif NSA_DECODE_IMPL == "aiter":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_aiter(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
layer=layer,
metadata=metadata,
bs=forward_batch.batch_size,
)
        else:
            raise ValueError(f"Unsupported {NSA_DECODE_IMPL = }")
def _forward_fa3(
self,
q_rope: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
q_nope: torch.Tensor,
page_table: torch.Tensor,
cache_seqlens: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
sm_scale: float,
logit_cap: float,
page_size: int,
) -> torch.Tensor:
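        # kv_cache packs [c_kv (v_head_dim wide) | k_rope] along the last dim;
        # for DeepSeek MLA these are 512- and 64-dim respectively (illustrative).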
k_rope_cache = kv_cache[:, :, v_head_dim:]
c_kv_cache = kv_cache[:, :, :v_head_dim]
qk_rope_dim = k_rope_cache.shape[-1]
k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=sm_scale,
causal=True,
softcap=logit_cap,
return_softmax_lse=False,
num_splits=self.num_splits,
)
return o # type: ignore

    def _forward_flashmla_prefill(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from flash_mla import flash_mla_sparse_fwd
o, _, _ = flash_mla_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)
return o

    def _forward_flashmla_decode(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
sm_scale: float,
layer,
metadata: NSAMetadata,
page_table_1,
) -> torch.Tensor:
from flash_mla import flash_mla_with_kvcache
cache_seqlens = metadata.nsa_cache_seqlens_int32
        # TODO: the 2nd dim is seq_len_q; it must become > 1 once MTP is supported
q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
assert self.real_page_size == 64, "only page size 64 is supported"
if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
# inefficiently quantize the whole cache
kv_cache = quantize_k_cache(kv_cache)
indices = page_table_1.unsqueeze(1)
assert (
indices.shape[-1] == self.nsa_index_topk
) # requirement of FlashMLA decode kernel
o, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv_cache,
cache_seqlens=cache_seqlens,
head_dim_v=v_head_dim,
tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
num_splits=metadata.flashmla_metadata.num_splits,
softmax_scale=sm_scale,
indices=indices,
            # the docs say block_table is unused, but passing None raises an error
block_table=torch.empty(
(q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
),
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
)
return o

    def _forward_tilelang(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
return tilelang_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)

    def _forward_aiter(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
page_table_1: torch.Tensor,
layer: RadixAttention,
metadata: NSAMetadata,
bs: int,
) -> torch.Tensor:
q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)
if layer.head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
kv_indptr = self.kv_indptr
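        # topk indices mark padded/unselected slots with -1; count the valid
        # entries per row to build kv_indptr, then gather only the valid indices.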
non_minus1_mask = page_table_1 != -1
non_minus1_counts = non_minus1_mask.sum(dim=1)
kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)
        kv_indices = page_table_1[non_minus1_mask]
mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
kv_cache.view(-1, 1, 1, layer.head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
metadata.cu_seqlens_q,
kv_indptr,
kv_indices,
metadata.cu_seqlens_q,
metadata.max_seq_len_q,
layer.scaling,
layer.logit_cap,
)
return o
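
    # NOTE(assumption): padded CUDA-graph slots are filled with seq_len = 1 so
    # kernels always see a valid, non-empty sequence; padded outputs are discarded.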
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
return 1

    def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)

    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
from flash_mla import get_mla_metadata
flashmla_metadata, num_splits = get_mla_metadata(
cache_seqlens=cache_seqlens,
            # TODO: the doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
            # but the parameter name suggests it expects seq_len_q?
            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
num_heads_k=1,
num_heads_q=self.num_q_heads,
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
topk=self.nsa_index_topk,
)
return NSAFlashMLAMetadata(
flashmla_metadata=flashmla_metadata,
num_splits=num_splits,
)