# File: sglang/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
if is_cuda():
try:
import deep_gemm
except ImportError as e:
deep_gemm = e
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
class BaseIndexerMetadata(ABC):
    """Abstract interface for attention-backend-specific indexer metadata."""

    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """Return a (batch_size,) int32 tensor of KV sequence lengths."""

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """Return a (batch_size, num_blocks) int32 page table.

        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """Return a (sum_extend_seq_len,) int32 tensor."""

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """Perform top-k selection on ``logits``, possibly transforming the result.

        An attention backend may override this to post-process the selection,
        so the return value is NOT guaranteed to be the raw top-k indices of
        the input logits. It is passed opaquely to the attention backend for
        further sparse-attention processing; do not assume its contents.
        """
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply a scaled Hadamard transform along the last dimension (bf16 only)."""
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform

    dim = x.size(-1)
    # Hadamard transforms are only defined for power-of-two sizes.
    assert (
        dim & (dim - 1)
    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform(x, scale=dim**-0.5)
class V32LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    The input is upcast to float32 for the normalization (weight and bias are
    kept in float32) and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        )
        return normalized.type_as(x)
class Indexer(CustomOp):
    """Indexer that selects top-k KV positions per query token for NSA sparse attention.

    Projects ``q_lora`` and the hidden states into a small per-head index
    space, scores cached KV positions, and returns top-k indices consumed by
    the attention backend.
    """

    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        # Optional secondary CUDA stream used to overlap q/k work.
        self.alt_stream = alt_stream
        if is_cuda():
            self.sm_count = deep_gemm.get_num_sms()
            # Half of the device SMs, aligned to 8, for dual-stream DeepGEMM.
            self.half_device_sm_count = align(self.sm_count // 2, 8)
        # q_lora -> per-head index queries.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        # hidden states -> single shared index key.
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weights_proj is deliberately not quantized.
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        # Quantization block size for act_quant (head_dim // block_size scales).
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5
def _forward_fake(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
):
bs = x.shape[0]
assert self.index_topk == 2048
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
None, ...
].repeat(bs, 1)
if forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.seq_lens_cpu is not None
)
which = 0
for i, (kv_len, qo_len) in enumerate(
zip(
forward_batch.seq_lens_cpu.tolist(),
forward_batch.extend_seq_lens_cpu,
strict=True,
)
):
for j in range(kv_len - qo_len, kv_len):
ans[which, j + 1 :] = -1
which += 1
assert which == ans.shape[0]
else:
assert forward_batch.seq_lens_cpu is not None
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
ans[i, seq_len:] = -1
return ans
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        """Project q_lora/x to bf16 indexer query and key heads.

        Applies rope to the leading ``rope_head_dim`` channels in place, then
        a Hadamard rotation to the full vectors. With ``enable_dual_stream``
        the key path runs on ``self.alt_stream`` concurrently with the query
        path; the wait_stream ordering below is load-bearing.
        """
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            # Fork: alt_stream must observe all prior work on the main stream.
            self.alt_stream.wait_stream(current_stream)
            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                # Query path on the main stream, capped to half the SMs so
                # the key path can use the rest.
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            with torch.cuda.stream(self.alt_stream):
                # TODO we should also put DeepGEMM half SM here?
                key, _ = self.wk(x)
                key = self.k_norm(key)
                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            # Join: main stream must not touch `key` before alt_stream is done.
            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
            key, _ = self.wk(x)
            key = self.k_norm(key)
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
        # Rotate only the rope sub-dimension, then write it back in place.
        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)
            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)
        return query, key
    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Score the paged fp8 KV cache and return the backend's top-k result.

        Used on the decode/idle path; relies on deep_gemm's paged MQA logits
        kernel, which hardcodes a 64-token page (block) size.
        """
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
        assert page_size == 64, "only support page size 64"
        # NOTE(dark): this support extend/decode/decode+graph
        block_tables = metadata.get_page_table_64()
        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )
        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # self.sm_count is queried from deep_gemm at init (e.g. 132 on H200/B200).
        schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
            seqlens_32, blocksize, self.sm_count
        )
        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        block_kv = 64
        num_heads_kv = 1
        # 132 bytes per token row: presumably 128 fp8 key values plus a 4-byte
        # fp32 scale — confirm against deep_gemm's expected cache layout.
        head_dim_with_sf = 132
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)
        logits = deep_gemm.fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            schedule_metadata,
            max_seq_len,
            clean_logits=False,
        )
        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result
def _get_topk_ragged(
self,
forward_batch: ForwardBatch,
layer_id: int,
q_fp8: torch.Tensor,
weights: torch.Tensor,
metadata: BaseIndexerMetadata,
) -> torch.Tensor:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
page_size = forward_batch.token_to_kv_pool.page_size
assert page_size == 64, "only support page size 64"
assert len(weights.shape) == 3
weights = weights.squeeze(-1)
k_fp8_list = []
k_scale_list = []
ks_list = []
offset = 0
block_tables = metadata.get_page_table_64()
assert (
forward_batch.seq_lens_cpu is not None
and forward_batch.extend_seq_lens_cpu is not None
)
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens_cpu[i].item()
assert isinstance(seq_len, int)
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
layer_id,
seq_len,
block_tables[i],
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
offset += extend_seq_len
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
seq_lens_expanded = metadata.get_seqlens_expanded()
ke = ks + seq_lens_expanded
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
assert logits.shape[0] == len(seq_lens_expanded)
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result
def forward_indexer_bs_1(
self,
q_fp8: torch.Tensor,
weights: torch.Tensor,
forward_batch: ForwardBatch,
topk: int,
layer_id: int,
) -> Optional[torch.Tensor]:
if not is_npu():
from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index
page_size = forward_batch.token_to_kv_pool.page_size
assert page_size == 64, "only support page size 64"
assert len(weights.shape) == 3
weights = weights.squeeze(-1)
# logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
k_fp8_list = []
k_scale_list = []
topk_indices_list = []
block_tables = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
strided_indices = torch.arange(
0, block_tables.shape[-1], page_size, device="cuda"
)
block_tables = block_tables[:, strided_indices] // page_size
q_len_start = 0
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens[i].item()
q_len = (
forward_batch.extend_seq_lens_cpu[i]
if forward_batch.forward_mode.is_extend()
else 1
)
q_len_end = q_len_start + q_len
q_fp8_partial = q_fp8[q_len_start:q_len_end]
q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()
weights_partial = weights[q_len_start:q_len_end]
weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()
k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()
index_score = fp8_index(
q_fp8_partial,
weights_partial,
k_fp8,
k_scale,
)
end_pos = seq_len
topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)
pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
topk_indices = torch.nn.functional.pad(
topk_indices, (0, pad_len), "constant", -1
)
topk_indices_list.append(topk_indices)
q_len_start = q_len_end
topk_indices = torch.cat(topk_indices_list, dim=0)
return topk_indices
    def forward_indexer(
        self,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        forward_batch: ForwardBatch,
        topk: int,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Dispatch to the per-request indexer implementation."""
        return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
    def _forward(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Compute top-k KV indices for this layer, or None to skip NSA.

        Builds bf16 query/key, quantizes both to fp8, writes the key (+scale)
        into the KV pool, then dispatches to the paged (decode) or ragged
        (extend) scoring path on CUDA, or the generic indexer elsewhere.
        """
        # Pick the act_quant kernel for the current platform.
        if is_hip():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
        elif not is_npu():
            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        metadata = forward_batch.attn_backend.get_indexer_metadata(
            layer_id, forward_batch
        )
        # Dual-stream only during CUDA graph capture with a small token count.
        enable_dual_stream = (
            NSA_DUAL_STREAM
            and self.alt_stream is not None
            and get_is_capture_mode()
            and q_lora.shape[0] > 0
            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )
        # skip NSA if attention backend choose to skip this batch
        if metadata is None:
            return None
        if not NSA_USE_REAL_INDEXER:  # temporary
            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            with torch.cuda.stream(self.alt_stream):
                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )
        weights = self._get_logits_head_gate(x, q_scale)
        if is_cuda():
            assert forward_batch.seq_lens_cpu is not None
            if len(forward_batch.seq_lens_cpu) == 0:
                # Empty seq_lens (max-pad case): return an all-invalid result.
                # this seems b/c max-pad, no worries?
                # if x.shape[0] != 0:
                #     print(
                #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
                #     )
                return torch.full(
                    (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
                )
            if forward_batch.forward_mode.is_decode_or_idle():
                topk_result = self._get_topk_paged(
                    forward_batch, layer_id, q_fp8, weights, metadata
                )
            else:
                topk_result = self._get_topk_ragged(
                    forward_batch, layer_id, q_fp8, weights, metadata
                )
        else:
            topk_result = self.forward_indexer(
                q_fp8.contiguous(),
                weights,
                forward_batch,
                topk=self.index_topk,
                layer_id=layer_id,
            )
        return topk_result
    def forward_cuda(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """CUDA entry point (CustomOp dispatch); delegates to ``_forward``."""
        return self._forward(x, q_lora, positions, forward_batch, layer_id)
    def forward_npu(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> torch.Tensor:
        """NPU entry point: build q/k, cache k, and call the NPU indexer op.

        When SGLANG_USE_AG_AFTER_QLORA is set (and layer_id >= 4), prefill
        inputs arrive sharded across attention TP ranks, so keys and the final
        top-k indices are all-gathered across the TP group.
        """
        import custom_ops
        import torch_npu

        from sglang.srt.layers.dp_attention import (
            get_attention_tp_rank,
            get_attention_tp_size,
        )
        from sglang.srt.utils import get_bool_env_var

        # Prefer the backend's CPU int seq lens when available.
        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = (
                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
            )
        enable_index_cp = (
            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
        )
        is_prefill = forward_batch.forward_mode.is_extend()
        attention_tp_rank = get_attention_tp_rank()
        attention_tp_size = get_attention_tp_size()

        # Expand cos/sin to the (tokens, 1, 1, rope_head_dim) layout that
        # npu_interleave_rope expects.
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        if is_prefill and enable_index_cp:
            # Keep only this rank's slice of the rope tables.
            slice_length = cos.shape[0] // attention_tp_size
            cos = cos[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]
            sin = sin[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]
        slot_mapping = forward_batch.out_cache_loc
        block_table = forward_batch.attn_backend.forward_metadata.block_tables
        bs = x.shape[0]

        # Query: project, split rope/nope channels, rope the leading part.
        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
        q_pe, q_nope = torch.split(
            q,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64, 64 + 64]
        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
            bs, self.n_heads, self.rope_head_dim
        )  # [bs, n, d]
        q = torch.cat([q_pe, q_nope], dim=-1)

        # Key: single shared head, normalized then roped the same way.
        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
        k = self.k_norm(k_proj)
        k_pe, k_nope = torch.split(
            k,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64 + 64]
        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
            bs, 1, self.rope_head_dim
        )  # [bs, 1, d]
        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]
        if is_prefill and enable_index_cp:
            # All-gather the sharded keys across the attention TP group.
            k, local_k = (
                torch.empty(
                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
                    dtype=k.dtype,
                    device=k.device,
                ),
                k,
            )
            get_attention_tp_group().all_gather_into_tensor(k, local_k)
        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)

        indexer_input = {}  # NOTE(review): unused — presumably leftover.
        if is_prefill:
            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
            # Cumulative query lengths (TND layout end offsets).
            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
                device=q.device
            )
            if enable_index_cp:
                # Shift to this rank's token window and clamp to [0, bs].
                actual_seq_lengths_q -= bs * attention_tp_rank
                actual_seq_lengths_q = torch.max(
                    actual_seq_lengths_q,
                    torch.zeros_like(actual_seq_lengths_q).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
                actual_seq_lengths_q = torch.min(
                    actual_seq_lengths_q,
                    torch.full(actual_seq_lengths_q.shape, bs).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
        else:
            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                # Decode: one query token per request -> offsets 1..bs.
                actual_seq_lengths_q = torch.tensor(
                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
                )
            else:
                actual_seq_lengths_q = (
                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
                )

        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
        x = x.view(-1, self.hidden_size)
        weights = self.weights_proj(x)[0]
        block_table = (
            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
        )
        topk_indices = torch.ops.custom.npu_lightning_indexer(
            query=q.view(-1, self.n_heads, self.head_dim),
            key=past_key_states,
            weights=weights,
            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
            block_table=block_table,
            layout_query="TND",
            layout_key="PA_BSND",
            sparse_count=self.index_topk,
            sparse_mode=3,
        )
        if is_prefill and enable_index_cp:
            # All-gather each rank's top-k indices back to the full batch.
            topk_indices, local_topk_indices = (
                torch.empty(
                    (
                        topk_indices.shape[0] * attention_tp_size,
                        topk_indices.shape[1],
                        topk_indices.shape[2],
                    ),
                    dtype=topk_indices.dtype,
                    device=topk_indices.device,
                ),
                topk_indices,
            )
            get_attention_tp_group().all_gather_into_tensor(
                topk_indices, local_topk_indices
            )
        return topk_indices