Support DeepSeek V3.2 Exp (#11061)
Co-authored-by: Stefan He <11166516+hebiao064@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <95566987+hnyls2002@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <56809903+fridge003@users.noreply.github.com>
Co-authored-by: DarkSharpness <76582120+darksharpness@users.noreply.github.com>
Co-authored-by: ZhengdQin <46387172+zhengdqin@users.noreply.github.com>
Co-authored-by: DarkSharpness <2040703891@qq.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Zhengda Qin <zhengdqin@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
@@ -49,6 +49,30 @@ class ModelImpl(str, Enum):
    TRANSFORMERS = "transformers"


def is_deepseek_nsa(config: PretrainedConfig) -> bool:
    return (
        config.architectures is not None
        and config.architectures[0]
        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
        and getattr(config, "index_topk", None) is not None
    )


def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
    assert is_deepseek_nsa(config)
    return config.index_head_dim


def get_nsa_index_topk(config: PretrainedConfig) -> int:
    assert is_deepseek_nsa(config)
    return config.index_topk


def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
    assert is_deepseek_nsa(config)
    return config.index_n_heads


class ModelConfig:
    def __init__(
        self,
@@ -271,6 +295,7 @@ class ModelConfig:
        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV32ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
@@ -283,6 +308,11 @@ class ModelConfig:
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim
            self.index_head_dim = (
                get_nsa_index_head_dim(self.hf_config)
                if is_deepseek_nsa(self.hf_config)
                else None
            )

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
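For context, a minimal sketch (not part of the diff) of what these helpers key off. The config object here is a stand-in for a Hugging Face `PretrainedConfig` carrying the DeepSeek-V3.2 NSA fields; the field values are illustrative:

```python
# Illustrative only: a stand-in config with the DeepSeek-V3.2 NSA fields.
from types import SimpleNamespace

cfg = SimpleNamespace(
    architectures=["DeepseekV32ForCausalLM"],
    index_topk=2048,     # tokens kept per query by the sparse indexer
    index_head_dim=128,  # indexer head dimension
    index_n_heads=64,    # indexer heads
)

# is_deepseek_nsa() requires both a matching architecture name and the
# presence of `index_topk`, so a plain DeepseekV3 config without the
# indexer fields returns False and the MLA path stays dense.
assert cfg.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
assert getattr(cfg, "index_topk", None) is not None
```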
@@ -2,9 +2,19 @@ import logging
import os
from typing import List, Optional

import torch

from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode

try:
    from mf_adapter import TransferEngine

    import_error = None
except ImportError as e:
    import_error = e

logger = logging.getLogger(__name__)


@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
-        try:
-            from mf_adapter import TransferEngine
-        except ImportError as e:
-            raise ImportError(
-                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
-            ) from e
+        if import_error is not None:
+            logger.warning(
+                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
+            )
+            raise import_error

        self.engine = TransferEngine()
        self.hostname = hostname
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
        self.initialize()

    def initialize(self) -> None:
        """Initialize the ascend transfer instance."""
        from sglang.srt.layers.dp_attention import (
            get_tensor_model_parallel_world_size,
            get_tp_group,
        )

        transfer_protocol = self._get_transfer_protocol()
        if transfer_protocol is None or transfer_protocol == "sdma":
            trans_op_type = TransferEngine.TransDataOpType.SDMA
        else:
            # Use device RDMA for PD transfer.
            trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
        tmp_tensor = torch.zeros(1, device="npu")
        output_tensor_list = [
            torch.empty_like(tmp_tensor)
            for _ in range(get_tensor_model_parallel_world_size())
        ]
        # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
        torch.distributed.all_gather(
            output_tensor_list, tmp_tensor, group=get_tp_group().device_group
        )
        ret_value = self.engine.initialize(
-            self.store_url,
-            self.session_id,
-            self.role,
-            self.npu_id,
+            self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
            ret_value = -1
        if ret_value != 0:
            logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")

    @staticmethod
    def _get_transfer_protocol():
        protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
        allowed_protocols = {"device_rdma", "sdma"}
        if protocol and protocol.lower() in allowed_protocols:
            return protocol.lower()
        else:
            logger.warning(
                "Invalid or no transfer protocol specified, using default protocol."
            )
            return None
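The transfer protocol is chosen from the `ASCEND_MF_TRANSFER_PROTOCOL` environment variable; anything other than `device_rdma` or `sdma` falls back to the default SDMA path. A quick standalone sketch of that selection logic:

```python
import os

def pick_protocol() -> str:
    # Mirrors _get_transfer_protocol(): only "device_rdma" and "sdma" are
    # accepted; anything else (or unset) falls back to the SDMA default.
    protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL", "")
    return protocol.lower() if protocol.lower() in {"device_rdma", "sdma"} else "sdma"

os.environ["ASCEND_MF_TRANSFER_PROTOCOL"] = "DEVICE_RDMA"
assert pick_protocol() == "device_rdma"  # case-insensitive match
del os.environ["ASCEND_MF_TRANSFER_PROTOCOL"]
assert pick_protocol() == "sdma"         # unset -> default
```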
@@ -36,6 +36,8 @@ class ForwardMetadata:
    seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_list: Optional[List[int]] = None
    seq_lens_list_cumsum: Optional[List[int]] = None
    seq_lens: Optional[torch.Tensor] = None
    actual_seq_lengths_q: Optional[torch.Tensor] = None


class AscendAttnBackend(AttentionBackend):
@@ -67,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
        if self.use_mla:
            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
            self.q_head_dim = (
                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
            )
        self.native_attn = TorchNativeAttnBackend(model_runner)
        self.graph_metadata = {}
        self.max_context_len = model_runner.model_config.context_len
@@ -102,10 +107,6 @@ class AscendAttnBackend(AttentionBackend):
        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
        if forward_batch.is_extend_in_batch:
            seq_lens_list_cumsum[-1] = (
                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
            ) * tp_size
        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

        self.graph_mode = False
@@ -133,6 +134,10 @@ class AscendAttnBackend(AttentionBackend):

        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
        metadata.seq_lens = seq_lens
        metadata.actual_seq_lengths_q = torch.tensor(
            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
        )

        self.graph_metadata[bs] = metadata
        self.forward_metadata = metadata
@@ -161,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
        metadata.block_tables[:bs, max_seq_pages:].fill_(0)
        metadata.block_tables[bs:, :].fill_(0)

        metadata.seq_lens[:bs].copy_(seq_lens[:bs])

        self.forward_metadata = metadata

        self.graph_mode = True
@@ -168,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
    def get_cuda_graph_seq_len_fill_value(self):
        return 0
    def forward_sparse(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ):
        is_prefill = forward_batch.forward_mode.is_extend()

        if save_kv_cache:
            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, k_rope
            )
        q_nope, q_pe = q, q_rope
        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        block_table = self.forward_metadata.block_tables
        if is_prefill:
            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
        else:
            if self.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_qlen = (
                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
                )
            else:
                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
        if self.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = self.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int

        attn_out = torch.ops.custom.npu_sparse_flash_attention(
            query=q_nope,
            key=k_nope,
            value=k_nope,
            query_rope=q_pe,
            key_rope=k_pe,
            sparse_indices=topk_indices,
            scale_value=layer.scaling,
            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
            block_table=block_table,
            sparse_block_size=1,
            layout_query="TND",
            layout_kv="PA_BSND",
            sparse_mode=3,
        )

        return attn_out
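In the packed "TND" (tokens, heads, dim) layout used above, per-request boundaries are conveyed as inclusive cumulative lengths, exactly as `forward_sparse` builds `actual_seq_qlen`. A small illustration of both branches:

```python
import torch

# Prefill: cumulative sums mark where each request's tokens end in the pack.
extend_seq_lens = torch.tensor([3, 5, 2])            # tokens per request
actual_seq_qlen = torch.cumsum(extend_seq_lens, dim=0)
print(actual_seq_qlen.tolist())                      # [3, 8, 10]; request i owns rows [prev, cur)

# Decode: the degenerate case of one token per request.
bs = 4
decode_qlen = torch.arange(1, bs + 1, dtype=torch.int32)
print(decode_qlen.tolist())                          # [1, 2, 3, 4]
```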
    def forward_extend(
        self,
        q,
@@ -176,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ):
        if topk_indices is not None:
            return self.forward_sparse(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache,
                q_rope,
                k_rope,
                topk_indices,
            )
        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -437,10 +518,23 @@ class AscendAttnBackend(AttentionBackend):
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ):
        if is_mla_preprocess_enabled():
            # MLAPO does the kv_cache saving itself.
            save_kv_cache = False
        if topk_indices is not None:
            return self.forward_sparse(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache,
                q_rope,
                k_rope,
                topk_indices,
            )

        if self.graph_mode:
            return self.forward_decode_graph(
@@ -66,6 +66,13 @@ def create_ascend_backend(runner):
    return AscendAttnBackend(runner)


@register_attention_backend("nsa")
def create_nsa_backend(runner):
    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend

    return NativeSparseAttnBackend(runner)


@register_attention_backend("triton")
def create_triton_backend(runner):
    assert not runner.model_config.is_encoder_decoder, (
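The registry maps a backend name (here `"nsa"`) to a factory, so the attention backend can be selected by name at server startup. A minimal sketch of the decorator mechanics; the real `register_attention_backend` lives in sglang's attention registry module and its internals may differ from this standalone version:

```python
# Illustrative registry sketch; names other than register_attention_backend
# and create_nsa_backend are invented for this example.
_ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(factory):
        _ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator

@register_attention_backend("nsa")
def create_nsa_backend(runner):
    # The real factory imports NativeSparseAttnBackend lazily so non-NSA
    # deployments never pay the import cost.
    return ("NativeSparseAttnBackend", runner)

backend = _ATTENTION_BACKENDS["nsa"]("dummy_runner")
```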
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch

if TYPE_CHECKING:
    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
    from sglang.srt.speculative.spec_info import SpecInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
    def support_triton(self):
        """Check if the current backend supports triton."""
        return True

    def get_indexer_metadata(
        self,
        layer_id: int,
        forward_batch: ForwardBatch,
    ) -> Optional[BaseIndexerMetadata]:
        """Get the indexer metadata. None means the backend doesn't support the indexer."""
        return None
@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
        return backend.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> Optional[BaseIndexerMetadata]:
        backend = self._select_backend(forward_batch.forward_mode)
        return backend.get_indexer_metadata(layer_id, forward_batch)
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
        self.rotary_emb = rotary_emb
        self.layer_id = layer_id
        self.has_preprocess_weights = False
        self.dtype = None

        self.q_lora_rank = self.q_b_proj.input_size  # 1536
        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
        self.num_local_heads = num_local_heads  # tp
        self.qk_nope_head_dim = qk_nope_head_dim  # 128
        self.qk_rope_head_dim = qk_rope_head_dim  # 64
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

    def preprocess_weights(self, hidden_states):
        self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
        slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
        return k_cache, v_cache, slot_mapping

-    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+    def forward_absorb_prepare_npu_rms_norm_cache(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch,
+        zero_allocator,
+    ):
        bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
        self.dtype = hidden_states.dtype
        self.cos, self.sin = self.get_sin_cos(positions)
        self.kvCache, self.kvCacheRope, self.slotmapping = (
            self.get_kv_cache_and_cache_idx(forward_batch)
        )

        if not self.has_preprocess_weights:
            self.has_preprocess_weights = True

        cos, sin = self.cos, self.sin

        if self.q_lora_rank is not None:
            fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
            q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q_lowrank)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )  # b*s,n,d

        q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
        q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)

        q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin)  # (B,N,S,D)
        q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)

        latent_cache = latent_cache.view(
            -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
        )  # (B*S,N,1,D)

        cache_mode = "PA_BNSD"
        self.kvCache = self.kvCache.view(
            -1,
            forward_batch.attn_backend.page_size,
            1,
            forward_batch.attn_backend.kv_lora_rank,
        )
        self.kvCacheRope = self.kvCacheRope.view(
            -1,
            forward_batch.attn_backend.page_size,
            1,
            forward_batch.attn_backend.qk_rope_head_dim,
        )
        k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            latent_cache,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            self.slotmapping.to(torch.int64),
            self.kvCacheRope,
            self.kvCache,
            epsilon=self.kv_a_layernorm.variance_epsilon,
            cache_mode=cache_mode,
        )

        return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
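The `q_nope @ w_kc` step above is the standard MLA weight-absorption trick: rather than up-projecting the cached latent KV for every token, the query's nope part is projected down into the latent space once. A shape-level sketch using the dimensions commented in the diff (128-dim nope heads, 512-dim KV latent); `w_kc` here is a random stand-in, not the real absorbed weight:

```python
import torch

# Shape-level sketch of MLA weight absorption (dims from the diff's comments).
num_tokens, num_heads = 8, 16
qk_nope_head_dim, kv_lora_rank = 128, 512

q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # stand-in weight

# (heads, tokens, d_nope) @ (heads, d_nope, d_latent), then back to
# (tokens, heads, d_latent): queries now attend directly in latent space.
q_latent = torch.matmul(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_latent.shape == (num_tokens, num_heads, kv_lora_rank)
```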
    def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
        input_dtype = hidden_states.dtype
        if not self.has_preprocess_weights:
            self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
            zero_allocator,
            positions,
        )

    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
        _is_w8a8 = (
            hasattr(self.qkv_a_proj.quant_method, "quantization_config")
            and self.qkv_a_proj.quant_method.quantization_config.get_name()
            == "w8a8_int8"
        )
        if _is_w8a8:
            return self.forward_mlapo(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            return self.forward_absorb_prepare_npu_rms_norm_cache(
                positions, hidden_states, forward_batch, zero_allocator
            )
163  python/sglang/srt/layers/attention/nsa/dequant_k_cache.py  (new file)
@@ -0,0 +1,163 @@
import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST


def dequantize_k_cache(quant_k_cache):
    if NSA_DEQUANT_K_CACHE_FAST:
        return _dequantize_k_cache_fast_wrapped(quant_k_cache)
    else:
        return _dequantize_k_cache_slow(quant_k_cache)


def _dequantize_k_cache_slow(
    quant_k_cache: torch.Tensor,  # (num_blocks, block_size, 1, bytes_per_token)
    dv: int = 512,
    tile_size: int = 128,
    d: int = 576,
) -> torch.Tensor:
    """De-quantize the k-cache."""
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, _ = quant_k_cache.shape
    assert h_k == 1
    result = torch.empty(
        (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
    )

    quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)

    input_nope = quant_k_cache[..., :dv]
    input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
    input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
    result[..., dv:] = input_rope

    for tile_idx in range(0, num_tiles):
        cur_nope = input_nope[
            ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
        ].to(torch.float32)
        cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
        result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_nope * cur_scales
        )

    result = result.view(num_blocks, block_size, 1, d)
    return result
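The 656-byte-per-token figure asserted in the fast path below falls out of the quantization layout used above: 512 FP8 "nope" bytes, one FP32 scale per 128-wide tile, and the 64 RoPE dims kept in BF16. A quick check of that arithmetic:

```python
# Per-token byte budget of a quantized k-cache entry (dv=512, tile_size=128).
dv, tile_size, d_rope = 512, 128, 64
nope_bytes = dv                       # 512 fp8 values, 1 byte each
scale_bytes = (dv // tile_size) * 4   # 4 tiles x 4-byte fp32 scale = 16 bytes
rope_bytes = d_rope * 2               # 64 bf16 values = 128 bytes
assert nope_bytes + scale_bytes + rope_bytes == 656  # matches dim_quant below
```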
def _dequantize_k_cache_fast_wrapped(
    quant_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    # TODO: the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_quant = quant_k_cache.shape
    assert dv == 512
    assert dim_quant == 656
    assert tile_size == 128
    quant_k_cache = quant_k_cache.view((-1, dim_quant))

    output = _dequantize_k_cache_fast(quant_k_cache)

    return output.view(num_blocks, block_size, 1, -1)


def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
    num_tokens, dim_quant = quant_k_cache.shape

    assert quant_k_cache.dtype == torch.float8_e4m3fn
    dim_nope = 512
    dim_rope = 64
    num_tiles = dim_nope // group_size
    assert dim_quant == 656

    output = torch.empty(
        (num_tokens, dim_nope + dim_rope),
        dtype=torch.bfloat16,
        device=quant_k_cache.device,
    )

    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5

    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    input_nope_q = quant_k_cache[:, :dim_nope]
    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
        torch.float32
    )
    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output,
        input_nope_q,
        input_nope_s,
        input_rope,
        output.stride(0),
        input_nope_q.stride(0),
        input_nope_s.stride(0),
        input_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
    )

    return output


@triton.jit
def _dequantize_k_cache_fast_kernel(
    output_ptr,
    input_nope_q_ptr,
    input_nope_s_ptr,
    input_rope_ptr,
    output_stride_0: int,
    input_nope_q_stride_0: int,
    input_nope_s_stride_0: int,
    input_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
):
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. dequant nope
        effective_block_id = raw_block_id

        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs_q < DIM_NOPE
        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id

        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
        y_s = tl.load(ptr_s)

        y = (y_q * y_s).to(output_ptr.dtype.element_ty)

        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
        tl.store(dst_ptr, y, mask=mask)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE

        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs

        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
        tl.store(dst_ptr, data, mask=mask)


if __name__ == "__main__":
    raise Exception("UT is in quant_k_cache.py")
354  python/sglang/srt/layers/attention/nsa/index_buf_accessor.py  (new file)
@@ -0,0 +1,354 @@
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool

"""
k: data, 128 items per token, fp8
s: scale, 1 item per token, fp32
"""


class GetK:
    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        index_k_fp8 = torch.empty(
            (seq_len_, pool.index_head_dim),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)

        return index_k_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim), uint8
        """

        # can handle per 128B instead of per element

        # page_indices: (num_pages,), element := a page index
        buf_numel_per_page = buf.shape[1]

        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
        num_k_bytes_per_token = pool.index_head_dim

        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
        # flat_buf: (whatever,), uint8
        flat_buf = buf.flatten()

        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
            num_k_bytes_per_page, dtype=torch.int32, device="cuda"
        )[None, :]
        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]

        out = flat_buf[flat_indices]
        return out.view(-1, 128)
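`GetK.torch_fast` replaces the per-page Python loop of `slow` with a single fused gather: page indices are expanded into absolute byte offsets into the flattened buffer. A toy-sized version of the same indexing (2 tokens per page, 2 bytes per token, so the shapes stay readable):

```python
import torch

# Toy version of GetK.torch_fast's indexing: gather whole pages of k-bytes
# from a flat uint8 buffer in one shot.
page_size, bytes_per_token = 2, 2
buf = torch.arange(24, dtype=torch.uint8).view(6, page_size * bytes_per_token)

page_indices = torch.tensor([4, 1])  # logical pages of one sequence, in order
flat = buf.flatten()
offsets = (page_indices * buf.shape[1])[:, None] + torch.arange(
    page_size * bytes_per_token
)[None, :]
seq_len = 3  # last page only half full, so the tail bytes are dropped
k_bytes = flat[offsets.flatten()[: seq_len * bytes_per_token]]
print(k_bytes.view(seq_len, bytes_per_token))  # rows = tokens, in page order
```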
class GetS:
    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        assert pool.index_head_dim // pool.quant_block_size == 1
        index_k_scale_fp8 = torch.empty(
            (seq_len_, 4),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
        return index_k_scale_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim // quant_block_size), uint8
        """
        buf_numel_per_page = buf.shape[1]

        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        flat_buf = buf.flatten()
        flat_indices = (
            (page_indices * buf_numel_per_page)[:, None]
            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
                None, :
            ]
            + s_offset_in_page
        )
        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]

        out = flat_buf[flat_indices]
        return out.view(-1, 4)


class SetK:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            buf[
                page_index,
                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
            ] = index_k[i].view(torch.uint8)

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_token = pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
                None, :
            ]
        )
        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()


class SetS:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            start = pool.page_size * pool.index_head_dim
            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
                index_k_scale[i].view(torch.uint8)
            )

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_token = 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + s_offset_in_page
            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
                None, :
            ]
        )
        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
class SetKAndS:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        if 0:
            # print("SetK, SetS comparison test")
            buf_cloned = buf.clone()
            cls.vanilla(*args, **kwargs, buf=buf)
            cls.triton(*args, **kwargs, buf=buf_cloned)

            def _clear_token_0(target):
                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0

            _clear_token_0(buf)
            _clear_token_0(buf_cloned)

            assert torch.all(
                buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}"
            return

        cls.triton(*args, **kwargs, buf=buf)

    @classmethod
    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

    @classmethod
    def triton(cls, pool, buf, loc, index_k, index_k_scale):
        _set_k_and_s_triton(
            buf=buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
            page_size=pool.page_size,
        )


def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    index_k: torch.Tensor,
    index_k_scale: torch.Tensor,
    page_size: int,
):
    """
    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
    :param loc: (num_tokens_to_write,), int, element := the token index to write to
    :param index_k: (num_tokens_to_write, 128 elem), fp8
    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
    """
    num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape
    num_tokens_to_write_, index_head_dim = index_k.shape
    num_tokens_to_write__, scale_dim = index_k_scale.shape
    assert buf_numel_per_page == 64 * (128 + 4)
    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
    assert index_head_dim == 128
    assert scale_dim == 1
    assert page_size == 64

    assert buf.dtype == torch.uint8
    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
    assert index_k.dtype == torch.float8_e4m3fn
    assert index_k_scale.dtype == torch.float32

    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert index_k.is_contiguous()
    assert index_k_scale.is_contiguous()

    buf_fp8 = buf.view(torch.float8_e4m3fn)
    buf_fp32 = buf.view(torch.float32)

    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_fp32,
        loc,
        index_k,
        index_k_scale,
        index_k.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
    )


@triton.jit
def _set_k_and_s_triton_kernel(
    buf_fp8_ptr,
    buf_fp32_ptr,
    loc_ptr,
    index_k_ptr,
    index_k_scale_ptr,
    index_k_ptr_stride_0,
    PAGE_SIZE: tl.constexpr,
    BUF_NUMEL_PER_PAGE: tl.constexpr,
    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
    token_id = tl.program_id(0)

    loc = tl.load(loc_ptr + token_id)

    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)

    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
    k = tl.load(index_k_ptr + in_k_offsets)
    k_scale = tl.load(index_k_scale_ptr + token_id)

    loc_page_index = loc // PAGE_SIZE
    loc_token_offset_in_page = loc % PAGE_SIZE

    out_k_offsets = (
        loc_page_index * BUF_NUMEL_PER_PAGE
        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    )

    # "//4" b/c it is fp32 instead of uint8
    out_s_offset = (
        loc_page_index * BUF_NUMEL_PER_PAGE // 4
        + S_OFFSET_NBYTES_IN_PAGE // 4
        + loc_token_offset_in_page
    )

    tl.store(buf_fp8_ptr + out_k_offsets, k)
    tl.store(buf_fp32_ptr + out_s_offset, k_scale)
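The kernel writes each token's 128 FP8 k-bytes and its single FP32 scale into one 64-token page of the `64 * (128 + 4)`-byte buffer; the scale offset is divided by 4 because the same memory is also viewed as FP32. The address arithmetic for one token, worked out in plain Python:

```python
# Address math of _set_k_and_s_triton_kernel for one token. Page layout:
# 64 tokens x 128 fp8 k-bytes, then 64 tokens x 4-byte fp32 scales.
PAGE_SIZE, K_BYTES, BUF_PER_PAGE = 64, 128, 64 * (128 + 4)

loc = 70                                     # global token slot
page, off = divmod(loc, PAGE_SIZE)           # -> page 1, offset 6
k_start = page * BUF_PER_PAGE + off * K_BYTES                     # byte offset of k data
s_start = (page * BUF_PER_PAGE + PAGE_SIZE * K_BYTES) // 4 + off  # fp32-element offset
assert (page, off) == (1, 6)
assert k_start == 1 * 8448 + 6 * 128         # 9216
assert s_start == (8448 + 8192) // 4 + 6     # 4166
```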
761  python/sglang/srt/layers/attention/nsa/nsa_indexer.py  (new file)
@@ -0,0 +1,761 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu

if is_cuda():
    import deep_gemm

from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool

DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0


class BaseIndexerMetadata(ABC):
    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """
        Return: (batch_size,) int32 tensor
        """

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """
        Return: (batch_size, num_blocks) int32 page table.
        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """
        Return: (sum_extend_seq_len,) int32 tensor
        """

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """
        Perform top-k selection on the logits and possibly transform the result.

        NOTE that the attention backend may override this function to do some
        transformation, which means the result of topk_transform may not
        be the top-k indices of the input logits.

        Return: Anything, since it will be passed to the attention backend
            for further processing on sparse attention computation.
            Don't assume it is the top-k indices of the input logits.
        """
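Backends opt into NSA by returning an object with this interface from `get_indexer_metadata`. A minimal sketch of a conforming implementation, assuming `BaseIndexerMetadata` from above is in scope; the fields and the decode-only simplification are invented for illustration, a real backend derives all of this from its forward batch:

```python
import torch

class ToyIndexerMetadata(BaseIndexerMetadata):
    """Illustrative only: serves fixed metadata for a decode batch."""

    def __init__(self, seqlens: torch.Tensor, page_table: torch.Tensor):
        self._seqlens = seqlens.to(torch.int32)        # (batch_size,)
        self._page_table = page_table.to(torch.int32)  # (batch_size, num_blocks), page size 64

    def get_seqlens_int32(self) -> torch.Tensor:
        return self._seqlens

    def get_page_table_64(self) -> torch.Tensor:
        return self._page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
        # Decode case: one query token per request.
        return self._seqlens

    def topk_transform(self, logits: torch.Tensor, topk: int) -> torch.Tensor:
        # Plain top-k; real backends may emit a backend-specific format instead.
        k = min(topk, logits.shape[-1])
        return logits.topk(k, dim=-1).indices.to(torch.int32)
```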
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform

    hidden_size = x.size(-1)
    assert (
        hidden_size & (hidden_size - 1)
    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform(x, scale=hidden_size**-0.5)
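The scaled Hadamard transform is an orthogonal rotation that spreads activation outliers across the hidden dimension before FP8 quantization. A slow pure-PyTorch reference of the same transform, assuming the usual semantics of `hadamard_transform(x, scale)` (the fused kernel is what production uses); the orthogonality check shows why it is quantization-friendly:

```python
import torch

def hadamard_reference(x: torch.Tensor) -> torch.Tensor:
    """Slow O(d^2) reference of hadamard_transform(x, scale=d**-0.5)."""
    d = x.size(-1)
    assert d & (d - 1) == 0, "power of 2 required"
    h = torch.ones(1, 1)
    while h.shape[0] < d:  # Sylvester construction: H_{2n} = [[H, H], [H, -H]]
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return x @ (h * d**-0.5)

x = torch.randn(4, 128)
y = hadamard_reference(x)
# The scaled transform is orthogonal, so per-row norms are preserved.
assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)
```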
class V32LayerNorm(nn.Module):
    """Layer normalization (computed in fp32, cast back to the input dtype)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)
class Indexer(CustomOp):
    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        if is_cuda():
            self.sm_count = deep_gemm.get_num_sms()
            self.half_device_sm_count = align(self.sm_count // 2, 8)

        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weights_proj is not quantized
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5
    def _forward_fake(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ):
        bs = x.shape[0]
        assert self.index_topk == 2048
        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
            None, ...
        ].repeat(bs, 1)
        if forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.seq_lens_cpu is not None
            )
            which = 0
            for i, (kv_len, qo_len) in enumerate(
                zip(
                    forward_batch.seq_lens_cpu.tolist(),
                    forward_batch.extend_seq_lens_cpu,
                    strict=True,
                )
            ):
                for j in range(kv_len - qo_len, kv_len):
                    ans[which, j + 1 :] = -1
                    which += 1
            assert which == ans.shape[0]
        else:
            assert forward_batch.seq_lens_cpu is not None
            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
                ans[i, seq_len:] = -1

        return ans
    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights, _ = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights
    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            with torch.cuda.stream(self.alt_stream):
                # TODO we should also put DeepGEMM half SM here?
                key, _ = self.wk(x)
                key = self.k_norm(key)

                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )

            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)

            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )

            key, _ = self.wk(x)
            key = self.k_norm(key)
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )

        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)

        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)

            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)

        return query, key
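The dual-stream path overlaps the query projection and the key projection on separate CUDA streams, synchronizing with `wait_stream` before the results are combined. A generic sketch of that pattern with stand-in linear layers (requires a CUDA device; these are not the real indexer weights):

```python
import torch

# Generic two-stream overlap pattern, as used by _get_q_k_bf16 above.
device = "cuda"
wq_b = torch.nn.Linear(256, 512, bias=False, device=device)  # stand-in
wk = torch.nn.Linear(256, 128, bias=False, device=device)    # stand-in
x = torch.randn(64, 256, device=device)

main = torch.cuda.current_stream()
alt = torch.cuda.Stream()

alt.wait_stream(main)            # alt must see x fully written
q = wq_b(x)                      # runs on the main stream
with torch.cuda.stream(alt):
    k = wk(x)                    # runs concurrently on the alternate stream
main.wait_stream(alt)            # main must not consume k too early

out = q.sum() + k.sum()
torch.cuda.synchronize()
```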
    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
        assert page_size == 64, "only support page size 64"

        # NOTE(dark): this supports extend / decode / decode+graph
        block_tables = metadata.get_page_table_64()

        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )

        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): 132 is the SM count on H200/B200, not a magic number
        schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
            seqlens_32, blocksize, self.sm_count
        )

        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        block_kv = 64
        num_heads_kv = 1
        head_dim_with_sf = 132  # 128 fp8 k bytes + 4 bytes of fp32 scale per token
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)

        logits = deep_gemm.fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            schedule_metadata,
            max_seq_len,
            clean_logits=False,
        )

        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result
    def _get_topk_ragged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        page_size = forward_batch.token_to_kv_pool.page_size
        assert page_size == 64, "only support page size 64"
        assert len(weights.shape) == 3
        weights = weights.squeeze(-1)
        k_fp8_list = []
        k_scale_list = []
        ks_list = []
        offset = 0

        block_tables = metadata.get_page_table_64()

        assert (
            forward_batch.seq_lens_cpu is not None
            and forward_batch.extend_seq_lens_cpu is not None
        )

        for i in range(forward_batch.batch_size):
            seq_len = forward_batch.seq_lens_cpu[i].item()
            assert isinstance(seq_len, int)
            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
            k_fp8_list.append(k_fp8)
            k_scale_list.append(k_scale)
            ks_list.append(ks)
            offset += extend_seq_len

        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
        kv_fp8 = (k_fp8, k_scale)
        ks = torch.cat(ks_list, dim=0)
        seq_lens_expanded = metadata.get_seqlens_expanded()
        ke = ks + seq_lens_expanded

        logits = deep_gemm.fp8_mqa_logits(
            q_fp8,
            kv_fp8,
            weights,
            ks,
            ke,
            clean_logits=False,
        )

        assert logits.shape[0] == len(seq_lens_expanded)
        topk_result = metadata.topk_transform(logits, self.index_topk)

        return topk_result
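`fp8_mqa_logits` receives the ragged batch as per-query-token [ks, ke) ranges into the concatenated KV. A small worked example of how those ranges are laid out, assuming two extend requests with no cached prefix (so each request's KV starts where its new tokens start):

```python
import torch

# Two extend requests with 3 and 2 new tokens; causal kv length per token.
extend_seq_lens = [3, 2]
seq_lens_expanded = torch.tensor([1, 2, 3, 1, 2], dtype=torch.int32)

ks_list, offset = [], 0
for extend_len in extend_seq_lens:
    ks_list.append(torch.full((extend_len,), offset, dtype=torch.int32))
    offset += extend_len
ks = torch.cat(ks_list)          # [0, 0, 0, 3, 3] -> each request's kv start
ke = ks + seq_lens_expanded      # [1, 2, 3, 4, 5] -> causal end per query token
print(ks.tolist(), ke.tolist())
```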
    def forward_indexer_bs_1(
        self,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        forward_batch: ForwardBatch,
        topk: int,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        if not is_npu():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index

        page_size = forward_batch.token_to_kv_pool.page_size
        assert page_size == 64, "only support page size 64"

        assert len(weights.shape) == 3
        weights = weights.squeeze(-1)

        # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
        k_fp8_list = []
        k_scale_list = []

        topk_indices_list = []

        block_tables = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, :
        ]
        strided_indices = torch.arange(
            0, block_tables.shape[-1], page_size, device="cuda"
        )
        block_tables = block_tables[:, strided_indices] // page_size

        q_len_start = 0

        for i in range(forward_batch.batch_size):
            seq_len = forward_batch.seq_lens[i].item()
            q_len = (
                forward_batch.extend_seq_lens_cpu[i]
                if forward_batch.forward_mode.is_extend()
                else 1
            )
            q_len_end = q_len_start + q_len

            q_fp8_partial = q_fp8[q_len_start:q_len_end]
            q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()

            weights_partial = weights[q_len_start:q_len_end]
            weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()

            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )

            k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()
            k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()

            index_score = fp8_index(
                q_fp8_partial,
                weights_partial,
                k_fp8,
                k_scale,
            )
            end_pos = seq_len
            topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)

            pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
            topk_indices = torch.nn.functional.pad(
                topk_indices, (0, pad_len), "constant", -1
            )

            topk_indices_list.append(topk_indices)

            q_len_start = q_len_end

        topk_indices = torch.cat(topk_indices_list, dim=0)

        return topk_indices
    def forward_indexer(
        self,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        forward_batch: ForwardBatch,
        topk: int,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
    def _forward(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        if not is_npu():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant

        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        metadata = forward_batch.attn_backend.get_indexer_metadata(
            layer_id, forward_batch
        )

        enable_dual_stream = (
            NSA_DUAL_STREAM
            and self.alt_stream is not None
            and get_is_capture_mode()
            and q_lora.shape[0] > 0
            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )

        # Skip NSA if the attention backend chooses to skip this batch.
        if metadata is None:
            return None

        if not NSA_USE_REAL_INDEXER:  # temporary
            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)

        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            with torch.cuda.stream(self.alt_stream):
                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)

        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) fp32
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp32
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )

        weights = self._get_logits_head_gate(x, q_scale)

        if is_cuda():
            assert forward_batch.seq_lens_cpu is not None
            if len(forward_batch.seq_lens_cpu) == 0:
                # This seems to be caused by max-padding (seq_lens empty while x
                # is not); hackily return an all-invalid topk_result.
                return torch.full(
                    (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
                )

            if forward_batch.forward_mode.is_decode_or_idle():
                topk_result = self._get_topk_paged(
                    forward_batch, layer_id, q_fp8, weights, metadata
                )
            else:
                topk_result = self._get_topk_ragged(
                    forward_batch, layer_id, q_fp8, weights, metadata
                )
        else:
            topk_result = self.forward_indexer(
                q_fp8.contiguous(),
                weights,
                forward_batch,
                topk=self.index_topk,
                layer_id=layer_id,
            )

        return topk_result
    def forward_cuda(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        return self._forward(x, q_lora, positions, forward_batch, layer_id)

    def forward_npu(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> torch.Tensor:
        import custom_ops
        import torch_npu

        from sglang.srt.layers.dp_attention import (
            get_attention_tp_group,  # needed by the all-gather paths below
            get_attention_tp_rank,
            get_attention_tp_size,
        )
        from sglang.srt.utils import get_bool_env_var

        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = (
                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
            )
        enable_index_cp = (
            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
        )
        is_prefill = forward_batch.forward_mode.is_extend()

        attention_tp_rank = get_attention_tp_rank()
        attention_tp_size = get_attention_tp_size()

        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        if is_prefill and enable_index_cp:
            slice_length = cos.shape[0] // attention_tp_size
            cos = cos[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]
            sin = sin[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]

        slot_mapping = forward_batch.out_cache_loc
        block_table = forward_batch.attn_backend.forward_metadata.block_tables

        bs = x.shape[0]

        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
        q_pe, q_nope = torch.split(
            q,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64, 64 + 64]

        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
            bs, self.n_heads, self.rope_head_dim
        )  # [bs, n, d]
        q = torch.cat([q_pe, q_nope], dim=-1)

        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
        k = self.k_norm(k_proj)
        k_pe, k_nope = torch.split(
            k,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64 + 64]

        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
            bs, 1, self.rope_head_dim
        )  # [bs, 1, d]
        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]

        if is_prefill and enable_index_cp:
            k, local_k = (
                torch.empty(
                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
                    dtype=k.dtype,
                    device=k.device,
                ),
                k,
            )
            get_attention_tp_group().all_gather_into_tensor(k, local_k)

        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)

        indexer_input = {}
        if is_prefill:
            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
                device=q.device
            )
            if enable_index_cp:
                actual_seq_lengths_q -= bs * attention_tp_rank
                actual_seq_lengths_q = torch.max(
                    actual_seq_lengths_q,
                    torch.zeros_like(actual_seq_lengths_q).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
                actual_seq_lengths_q = torch.min(
                    actual_seq_lengths_q,
                    torch.full(actual_seq_lengths_q.shape, bs).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
        else:
            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_lengths_q = torch.tensor(
                    [i + 1 for i in range(bs)], dtype=torch.int32, device=k.device
                )
            else:
                actual_seq_lengths_q = (
                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
                )

        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)

        x = x.view(-1, self.hidden_size)
        weights = self.weights_proj(x)[0]
        block_table = (
            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
        )

        topk_indices = torch.ops.custom.npu_lightning_indexer(
            query=q.view(-1, self.n_heads, self.head_dim),
            key=past_key_states,
            weights=weights,
            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
            block_table=block_table,
            layout_query="TND",
            layout_key="PA_BSND",
            sparse_count=self.index_topk,
            sparse_mode=3,
        )

        if is_prefill and enable_index_cp:
            topk_indices, local_topk_indices = (
                torch.empty(
                    (
                        topk_indices.shape[0] * attention_tp_size,
                        topk_indices.shape[1],
                        topk_indices.shape[2],
                    ),
                    dtype=topk_indices.dtype,
                    device=topk_indices.device,
                ),
                topk_indices,
            )
            get_attention_tp_group().all_gather_into_tensor(
                topk_indices, local_topk_indices
            )

        return topk_indices

python/sglang/srt/layers/attention/nsa/quant_k_cache.py (new file)
@@ -0,0 +1,255 @@
import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST


def quantize_k_cache(cache_k):
    # TODO: upstream can skip concat([k_nope, k_pe]) since we split them here
    if NSA_QUANT_K_CACHE_FAST:
        return _quantize_k_cache_fast_wrapped(cache_k)
    else:
        return _quantize_k_cache_slow(cache_k)


# Copied from original
def _quantize_k_cache_slow(
    input_k_cache: torch.Tensor,  # (num_blocks, block_size, h_k, d)
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """
    Quantize the k-cache.
    Returns a tensor with shape (num_blocks, block_size, h_k, dv + 4*(dv/tile_size) + t*(d-dv))
    of dtype torch.float8_e4m3fn (byte-compatible with uint8), where t = input_k_cache.element_size().
    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md.
    """
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, d = input_k_cache.shape
    assert h_k == 1
    input_k_cache = input_k_cache.squeeze(2)  # [num_blocks, block_size, d]
    input_elem_size = input_k_cache.element_size()

    result = torch.empty(
        (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
        dtype=torch.float8_e4m3fn,
        device=input_k_cache.device,
    )
    result_k_nope_part = result[..., :dv]
    result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
    result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
    result_k_rope_part[:] = input_k_cache[..., dv:]

    for tile_idx in range(0, num_tiles):
        cur_scale_factors_inv = (
            torch.abs(
                input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
            )
            .max(dim=-1)
            .values
            / 448.0
        )  # [num_blocks, block_size]
        result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv

        cur_scale_factors_inv.unsqueeze_(-1)  # [num_blocks, block_size, 1]
        cur_quantized_nope = (
            input_k_cache[
                ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
            ].float()
            / cur_scale_factors_inv.float()
        ).to(torch.float8_e4m3fn)
        result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_quantized_nope
        )

    result = result.view(num_blocks, block_size, 1, -1)
    return result
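
For intuition, a quick sanity check of the byte accounting behind the `result` allocation above (assuming the defaults dv=512, tile_size=128 and a bf16 input, so element_size() == 2):

dv, tile_size, d, elem_size = 512, 128, 576, 2
nope_bytes = dv  # one fp8 byte per nope element
scale_bytes = (dv // tile_size) * 4  # one float32 scale per 128-wide tile
rope_bytes = elem_size * (d - dv)  # rope part is kept in bf16, not quantized
assert nope_bytes + scale_bytes + rope_bytes == 656  # last-dim size of `result`
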
def _quantize_k_cache_fast_wrapped(
    input_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    # TODO: the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
    assert dv == 512
    assert dim_nope_and_rope == 512 + 64
    assert tile_size == 128
    input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))

    # TODO: deliberately split into two tensors, so upstream can provide the two
    # tensors instead of concat'ing them into one
    k_nope = input_k_cache[:, :dv]
    k_rope = input_k_cache[:, dv:]

    output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)

    return output.view(num_blocks, block_size, 1, -1)


def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
    """
    :param k_nope: (num_tokens, dim_nope 512)
    :param k_rope: (num_tokens, dim_rope 64)
    """

    assert k_nope.dtype == torch.bfloat16
    assert k_rope.dtype == torch.bfloat16

    num_tokens, dim_nope = k_nope.shape
    num_tokens_, dim_rope = k_rope.shape
    assert num_tokens == num_tokens_
    assert dim_nope == 512
    assert dim_rope == 64
    assert k_nope.dtype == k_rope.dtype
    num_tiles = dim_nope // group_size

    assert k_nope.stride(1) == 1
    assert k_rope.stride(1) == 1

    output = torch.empty(
        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
        dtype=torch.float8_e4m3fn,
        device=k_nope.device,
    )
    output_nope_q = output[..., :dim_nope]
    output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5

    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output_nope_q,
        output_nope_s,
        output_rope,
        k_nope,
        k_rope,
        output_nope_q.stride(0),
        output_nope_s.stride(0),
        output_rope.stride(0),
        k_nope.stride(0),
        k_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
    )

    return output


@triton.jit
def _quantize_k_cache_fast_kernel(
    output_nope_q_ptr,
    output_nope_s_ptr,
    output_rope_ptr,
    k_nope_ptr,
    k_rope_ptr,
    output_nope_q_stride_0: int,
    output_nope_s_stride_0: int,
    output_rope_stride_0: int,
    k_nope_stride_0: int,
    k_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. quantize the nope part
        effective_block_id = raw_block_id

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_NOPE
        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs

        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)

        # the ref impl does not have a `tl.maximum(..., eps)`, so we omit it here
        y_s = tl.max(tl.abs(y)) / FP8_MAX
        y_s_inv = 1.0 / y_s
        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
            output_nope_q_ptr.dtype.element_ty
        )

        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
        dst_s_ptr = (
            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
        )

        tl.store(dst_q_ptr, y_q, mask=mask)
        tl.store(dst_s_ptr, y_s)
    else:
        # b. copy the rope part unquantized
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE

        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs

        data = tl.load(src_ptr, mask=mask)
        tl.store(dst_ptr, data, mask=mask)


if __name__ == "__main__":
    for num_blocks, block_size in [
        (1, 1),
        (10, 64),
    ]:
        dim_nope_and_rope = 512 + 64

        input_k_cache = torch.randn(
            (num_blocks, block_size, 1, dim_nope_and_rope),
            dtype=torch.bfloat16,
            device="cuda",
        )

        ref_quant = _quantize_k_cache_slow(input_k_cache)
        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
        # NOTE: exact bitwise equality between ref_quant and actual_quant is not
        # expected; compare after dequantization instead.

        import dequant_k_cache

        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
            actual_quant
        )

        print(f"{ref_ref_dequant=}")
        print(f"{actual_actual_dequant=}")
        print(f"{actual_actual_dequant - ref_ref_dequant=}")
        print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")

        # TODO: the 0.2 tolerances are loose; investigate why the results differ this much
        torch.testing.assert_close(
            ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
        )
        torch.testing.assert_close(
            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
        )

    print("Passed")

python/sglang/srt/layers/attention/nsa/tilelang_kernel.py (new file)
@@ -0,0 +1,785 @@
from typing import Optional, Tuple

import tilelang
import tilelang.language as T
import torch

from sglang.srt.utils import is_hip

tilelang.set_log_level("WARNING")

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}

BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"

_is_hip = is_hip()


def fast_log2_ceil(x):
    # Extract the float32 exponent bits; round up when any mantissa bit is set.
    bits_x = T.reinterpret("uint32", x)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))


def fast_pow2(x):
    # Build 2**x by writing the biased exponent directly into the float32 bit pattern.
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)


def fast_round_scale(amax, fp8_max_inv):
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
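
The three helpers above round a scale up to the next power of two by operating directly on the float32 bit pattern (8-bit exponent, bias 127). A minimal NumPy check of the same bit manipulation:

import numpy as np

def ref_log2_ceil(x: float) -> int:
    bits = np.array(x, dtype=np.float32).view(np.uint32).item()
    exp, man = (bits >> 23) & 0xFF, bits & ((1 << 23) - 1)
    return exp - 127 + (1 if man != 0 else 0)

def ref_pow2(e: int) -> float:
    return np.array((e + 127) << 23, dtype=np.uint32).view(np.float32).item()

assert ref_log2_ceil(8.0) == 3  # exact power of two: ceil(log2(x)) == exponent
assert ref_log2_ceil(9.0) == 4  # nonzero mantissa rounds up
assert ref_pow2(-3) == 0.125  # 2**-3 rebuilt from exponent bits alone
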
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
    N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
    M = T.symbolic("M")
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale else 2
    blk_m = 32
    group_size = 128

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), scale_dtype)
            s_local = T.alloc_fragment((blk_m,), scale_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                for i, j in T.Parallel(blk_m, group_size):
                    y_local[i, j] = T.clamp(
                        x_local[i, j] / s_local[i], fp8_min, fp8_max
                    )
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = s_local[i]
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_


def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
    N = x.size(-1)
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
    kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    return y, s
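
A minimal usage sketch of act_quant (assumes a CUDA device with the tilelang toolchain available); the round trip should agree with the input up to fp8_e4m3 rounding:

x = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
y, s = act_quant(x)  # y: fp8_e4m3fn, s: (128, 512 // 128) float32 scales
x_rec = y.to(torch.float32).view(128, -1, 128) * s.unsqueeze(-1)
torch.testing.assert_close(
    x_rec.view(128, 512), x.to(torch.float32), atol=0.1, rtol=0.1
)
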
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int, clear_accum=True):
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")

    blk_n1 = 512
    blk_n2 = 128

    @T.prim_func
    def fp8_index_kernel_(
        q: T.Tensor[(b, m, h, d), FP8],
        q_s: T.Tensor[(b, m, h), FP32],
        k: T.Tensor[(b, n, d), FP8],
        k_s: T.Tensor[(b, n), FP32],
        o: T.Tensor[(b, m, n), FP32],
    ) -> None:
        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
            q_smem = T.alloc_shared((h, d), FP8)
            T.copy(q[i_b, i_m, 0, 0], q_smem)

            q_s_frag = T.alloc_fragment(h, FP32)
            T.copy(q_s[i_b, i_m, 0], q_s_frag)

            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
                k_smem = T.alloc_shared((blk_n2, d), FP8)
                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)

                k_s_frag = T.alloc_fragment(blk_n2, FP32)
                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)

                logits = T.alloc_fragment((blk_n2, h), FP32)
                T.gemm(
                    k_smem,
                    q_smem,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=clear_accum,
                )

                for i_h, i3_n in T.Parallel(h, blk_n2):
                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]

                logits_sum = T.alloc_fragment(blk_n2, FP32)
                T.reduce_sum(logits, logits_sum, dim=1)

                for i3_n in T.Parallel(blk_n2):
                    logits_sum[i3_n] *= k_s_frag[i3_n]

                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])

    return fp8_index_kernel_


def fp8_index(
    q: torch.Tensor,
    q_s: torch.Tensor,
    k: torch.Tensor,
    k_s: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the index score in FP8 precision.

    Args:
        q (torch.Tensor): The Q tensor, must be contiguous.
        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
        k (torch.Tensor): The K tensor, must be contiguous.
        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.

    fp8 q @ fp8 k -> fp32 logits
    relu(fp32 logits) * q_s (weights) -> fp32 logits
    fp32 logits -> fp32 logits_sum
    fp32 logits_sum * k_s (e8m0) -> fp32 index_score
    """
    if _is_hip:
        return fp8_index_kernel(q.shape[2], q.shape[3], False)(q, q_s, k, k_s)
    else:
        return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
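
Up to fp8 rounding, the kernel above computes the same score as this unoptimized PyTorch restatement of the docstring (a reference sketch, not a drop-in replacement):

def fp8_index_ref(q, q_s, k, k_s):
    # q: (b, m, h, d), q_s: (b, m, h), k: (b, n, d), k_s: (b, n)
    logits = torch.einsum("bmhd,bnd->bmhn", q.to(torch.float32), k.to(torch.float32))
    weighted = torch.relu(logits) * q_s.unsqueeze(-1)  # per-head gating weights
    return weighted.sum(dim=2) * k_s.unsqueeze(1)  # (b, m, n) index score
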
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def sparse_attention_fwd_kernel_v1(
    num_heads,
    dim,
    tail_dim,
    topk,
    *,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_I=64,
    num_stages=2,
    threads=256,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"padding correctness has not been checked yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"padding correctness has not been checked yet, tail_dim={tail_dim}"
    assert is_causal, "non-causal attention is not supported"
    assert (
        topk % block_I == 0
    ), "otherwise some index=0 entries would be loaded, causing the wrong KV to be read"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)

    batch = T.symbolic("batch")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")

    head_kv = num_heads // kv_group
    q_shape = [batch, seq_len, num_heads, dim + tail_dim]
    kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
    o_shape = [batch, seq_len, num_heads, dim]
    indices_shape = [batch, seq_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim

    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
            bx,
            by,
            bz,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            O_shared = T.alloc_shared([H_per_block, D], dtype)
            mask = T.alloc_fragment([BI], "bool")

            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)

            T.fill(acc_o, 0)
            T.fill(sumexp, 0)
            T.fill(m_i, -(2**30))  # avoid -inf - inf producing nan

            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            q_i = s_i
            max_kv_i = q_i

            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)

            for i_i in T.Pipelined(NI, num_stages=num_stages):

                for bi_i in T.Parallel(BI):
                    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0

                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
                    ]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
                    ]

                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(
                        mask[bi_i], 0, -T.infinity(acc_s.dtype)
                    )
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                T.copy(m_i, m_i_prev)
                T.reduce_max(acc_s, m_i, dim=1, clear=False)
                for h_i in T.Parallel(H_per_block):
                    alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp2(
                        acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                    )
                T.reduce_sum(acc_s, sumexp_i, dim=1)  # block-local partial sum
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
                for h_i, d_i in T.Parallel(H_per_block, D):
                    acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]

                T.copy(acc_s, S_shared)
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            # Rescale
            for h_i, d_i in T.Parallel(H_per_block, D):
                acc_o[h_i, d_i] /= sumexp[h_i]
            for h_i in T.Parallel(H_per_block):
                sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale

            T.copy(acc_o, O_shared)
            T.copy(acc_o, Output[b_i, s_i, H0:H1, :])

    return main
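
Both sparse kernels in this file use the same online-softmax recurrence: when a new KV block raises the running row max, the previous partial sums are rescaled by alpha = exp(m_prev - m_new) before the new block's contribution is added (exp2 is used in-kernel because sm_scale folds in log2(e)). A minimal NumPy sketch of the invariant:

import numpy as np

def online_softmax_denominator(scores: np.ndarray, block: int) -> float:
    m, total = -np.inf, 0.0
    for i in range(0, len(scores), block):
        s = scores[i : i + block]
        m_new = max(m, float(s.max()))
        total = total * np.exp(m - m_new) + np.exp(s - m_new).sum()  # rescale old sum
        m = m_new
    return total * np.exp(m)  # equals sum(exp(scores))

x = np.random.randn(256).astype(np.float32)
assert np.isclose(online_softmax_denominator(x, 64), np.exp(x).sum(), rtol=1e-4)
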
@tilelang.jit(
    out_idx=[-1],
    compile_flags=[
        "-O3",
        "-Wno-deprecated-declarations",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=-v,--register-usage-level=10",
        "-DNDEBUG",
    ],
)  # type: ignore
def sparse_attention_fwd_kernel_v2(
    num_heads: int,
    dim: int,
    tail_dim: int,
    topk: int,
    *,
    kv_group: int = 1,
    sm_scale: Optional[float] = None,
    block_I: int = 64,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"padding correctness has not been checked yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"padding correctness has not been checked yet, tail_dim={tail_dim}"
    assert (
        topk % block_I == 0
    ), "otherwise some index=0 entries would be loaded, causing the wrong KV to be read"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
    threads = 384

    batch = T.symbolic("batch")
    qo_len = T.symbolic("seq_len")
    num_pages = T.symbolic("num_pages")

    q_shape = [batch, qo_len, num_heads, dim + tail_dim]
    kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
    o_shape = [batch, qo_len, num_heads, dim]
    indices_shape = [batch, qo_len, kv_group, topk]

    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    H = num_heads
    padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    assert NI % 2 == 0, "NI should be a multiple of 2"
    D = dim
    D_tail = tail_dim
    if num_heads > 64:
        assert num_heads % 64 == 0, "num_heads should be a multiple of 64"
        REPLICATE_H = num_heads // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        """
        Q: [b, qo_len, H, D + D_tail] (bfloat16)
        KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
        Indices: [b, qo_len, kv_group, topk] (int32)
        """

        with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz):  # type: ignore
            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
            K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
            K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r
            is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
            is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")

            acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
            alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
            indices_local = T.alloc_local([1], indices_dtype)
            indices_tmp = T.alloc_local([1], indices_dtype)

            bar_q = T.alloc_barrier(arrive_count=384)
            bar_k_0_ready = T.alloc_barrier(arrive_count=128)
            bar_k_1_ready = T.alloc_barrier(arrive_count=128)
            bar_k_0_free = T.alloc_barrier(arrive_count=256)
            bar_k_1_free = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)

            bar_0_128 = T.alloc_barrier(arrive_count=128)
            bar_1_128 = T.alloc_barrier(arrive_count=128)
            bar_2_128 = T.alloc_barrier(arrive_count=128)
            bar_final = T.alloc_barrier(arrive_count=128)

            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H

            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            tx = T.get_thread_binding()

            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            T.barrier_arrive(bar_q)

            if tx < 128:
                T.set_max_nreg(240, 1)
                T.fill(sumexp, 0)
                T.fill(m_i, -(2**30))  # avoid -inf - inf producing nan
                T.fill(acc_o_l, 0)
                T.barrier_wait(bar_q, 0)

                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 0)

                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(
                        Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_0,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )

                    T.wait_wgmma(0)

                    if i_i != 0:
                        T.barrier_arrive(bar_sScale_and_sS_free)
                        T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)

                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # block-local partial sum
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)

                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_0_l, acc_o_l)

                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_0_free[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 1)

                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(
                        Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_1,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )

                    T.wait_wgmma(0)

                    T.barrier_arrive(bar_sScale_and_sS_free)
                    T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)

                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # block-local partial sum
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)

                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_1_l, acc_o_l)

                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_1_free[0])

                # Rescale
                for h_i in T.Parallel(H_per_block):
                    sum_exp_shared[h_i] = sumexp[h_i]
                T.barrier_arrive(bar_final)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_l[h_i, d_i] /= sumexp[h_i]
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
            elif tx >= 128 and tx < 256:
                # T.set_max_nreg(168, 1)
                T.fill(acc_o_r, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 0)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_0_r, acc_o_r)
                    T.barrier_arrive(bar_k_0_free[0])
                    T.barrier_arrive(bar_sScale_and_sS_free)

                    # Buffer 1
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 1)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_1_r, acc_o_r)
                    T.barrier_arrive(bar_k_1_free[0])
                    if i_i != T.ceildiv(NI, 2) - 1:
                        T.barrier_arrive(bar_sScale_and_sS_free)

                # Rescale
                T.barrier_wait(bar_final, 0)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]

                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
            elif tx >= 256:
                # producer
                T.set_max_nreg(80, 0)
                indices_local[0] = 0
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 0)

                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]

                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_0_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_0_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_0[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]

                    T.cp_async_barrier_noinc(bar_k_0_ready[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 1)

                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]

                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_1_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_1_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_1[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]

                    T.cp_async_barrier_noinc(bar_k_1_ready[0])

    return main


def tilelang_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> torch.Tensor:
    assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
    num_heads = q.shape[1]
    dim = q.shape[2]
    tail_dim = dim - d_v
    topk = indices.shape[-1]
    assert topk == 2048
    if _is_hip:
        kernel = sparse_attention_fwd_kernel_v1(
            num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1
        )
    else:
        kernel = sparse_attention_fwd_kernel_v2(
            num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
        )
    return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
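
A shape-level usage sketch of tilelang_sparse_fwd (assumes an NVIDIA GPU so the v2 kernel is selected; topk is currently hard-coded to 2048, and invalid index slots are marked with -1):

num_tokens, num_heads, d_v, d_rope, topk = 4, 64, 512, 64, 2048
q = torch.randn(num_tokens, num_heads, d_v + d_rope, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(8192, 1, d_v + d_rope, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(0, 8192, (num_tokens, 1, topk), dtype=torch.int32, device="cuda")
out = tilelang_sparse_fwd(q, kv, indices, sm_scale=(d_v + d_rope) ** -0.5)
print(out.shape)  # (1, num_tokens, num_heads, d_v)
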
python/sglang/srt/layers/attention/nsa/transform_index.py (new file)
@@ -0,0 +1,144 @@
from typing import List, Optional

import torch
import triton
import triton.language as tl


def transform_index_page_table_prefill(**kwargs):
    return transform_index_page_table_prefill_ref(**kwargs)


def transform_index_page_table_decode(**kwargs):
    return transform_index_page_table_decode_ref(**kwargs)


@triton.jit
def transform_index_page_table_decode_kernel(
    page_table_ptr,
    topk_indices_ptr,
    result_ptr,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
):
    TOPK: tl.constexpr = 2048
    req_id = tl.program_id(0)
    page_table_ptr = page_table_ptr + req_id * max_seqlen_k
    topk_indices_ptr = topk_indices_ptr + req_id * TOPK
    result_ptr = result_ptr + req_id * TOPK

    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    mask = loaded_topk_indices >= 0
    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    tl.store(result_ptr + offset, -1, mask=~mask)


def transform_index_page_table_decode_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """
    Transform the page table according to topk indices for sparse topk attention.
    Args:
        page_table: [qo_len, max_seqlen_k], the original page table
        topk_indices: [qo_len, topk], the topk indices for each query position
    Returns:
        transformed_page_table: [qo_len, topk], the transformed page table.
        For out-of-bound indices in topk_indices, the result is filled with -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    assert topk_indices.shape[1] == 2048
    qo_len = topk_indices.shape[0]
    max_seqlen_k = page_table.shape[1]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    # Launch triton kernel
    grid = (qo_len,)
    transform_index_page_table_decode_kernel[grid](
        page_table,
        topk_indices,
        result,
        page_size,
        max_seqlen_k=max_seqlen_k,
    )
    return result


def transform_index_page_table_prefill_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    offset = 0
    for i, l in enumerate(extend_lens_cpu):
        transform_index_page_table_decode_fast(
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
        )
        offset += l
    assert offset == topk_indices.shape[0]
    return result


def transform_index_page_table_decode_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert result.shape == topk_indices.shape
    torch.gather(
        page_table,
        dim=1,
        index=topk_indices.clamp(min=0),
        out=result,
    )
    result[topk_indices < 0] = -1
    return result
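
A tiny worked example of the reference semantics above: valid topk entries gather from the row's page table, and -1 entries stay -1:

page_table = torch.tensor([[10, 20, 30, 40]], dtype=torch.int32)
topk_indices = torch.tensor([[2, 0, -1, -1]])  # int64 indices for torch.gather
out = transform_index_page_table_decode_ref(page_table, topk_indices)
print(out)  # tensor([[30, 10, -1, -1]], dtype=torch.int32)
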
def transform_index_page_table_prefill_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    offset = 0
    for i, l in enumerate(extend_lens_cpu):
        transform_index_page_table_decode_ref(
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
        )
        offset += l
    assert offset == topk_indices.shape[0]
    return result


if __name__ == "__main__":
    bs, topk, max_seqlen = 10, 2048, 3000
    page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
    topk_indices = torch.full((bs, topk), -1, device="cuda")
    topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
    result = transform_index_page_table_decode_fast(page_table, topk_indices)
    assert torch.all(result == ref_result)
    print("Passed")

python/sglang/srt/layers/attention/nsa/utils.py (new file)
@@ -0,0 +1,24 @@
# Temporary environment variables for NSA debugging
from sglang.srt.utils import get_bool_env_var

NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")

NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
    "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "true")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "true")


def print_nsa_bool_env_vars():
    msg = ""
    for k, v in globals().items():
        if k.startswith("NSA_") and isinstance(v, bool):
            msg += f"{k}={v} "
    print(msg, flush=True)


def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
    return original_seq_lens.clamp(max=nsa_index_topk)
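
compute_nsa_seqlens caps every sequence length at the indexer topk, since sparse attention never attends to more than topk tokens per query. For example:

import torch

seq_lens = torch.tensor([100, 2048, 5000], dtype=torch.int32)
print(compute_nsa_seqlens(seq_lens, nsa_index_topk=2048))  # tensor([ 100, 2048, 2048], dtype=torch.int32)
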

python/sglang/srt/layers/attention/nsa_backend.py (new file)
@@ -0,0 +1,887 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias

import torch

from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.transform_index import (
    transform_index_page_table_decode,
    transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
    NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
    NSA_FUSE_TOPK,
    compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput

_is_hip = is_hip()

if _is_hip:
    try:
        from aiter import (
            flash_attn_varlen_func,
            mha_batch_prefill_func,
            paged_attention_ragged,
        )
        from aiter.mla import mla_decode_fwd, mla_prefill_fwd
    except ImportError:
        print(
            "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
        )
else:
    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass(frozen=True)
class NSAFlashMLAMetadata:
    """Metadata only needed by FlashMLA"""

    flashmla_metadata: torch.Tensor
    num_splits: torch.Tensor

    def slice(self, sli):
        return NSAFlashMLAMetadata(
            flashmla_metadata=self.flashmla_metadata,
            num_splits=self.num_splits[sli],
        )

    def copy_(self, other: "NSAFlashMLAMetadata"):
        self.flashmla_metadata.copy_(other.flashmla_metadata)
        self.num_splits.copy_(other.num_splits)


@dataclass(frozen=True)
class NSAMetadata:
    page_size: int

    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor
    # Maximum sequence length for query
    max_seq_len_q: int
    # Maximum sequence length for key
    max_seq_len_k: int
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor
    # Page table, the index of KV Cache Tables/Blocks;
    # this table is always with page_size = 1
    page_table_1: torch.Tensor

    # NOTE(dark): This property will be used in:
    # 1. dense decode/prefill, where paged flash attention needs real_page_table
    # 2. sparse decode/prefill, where the indexer needs real_page_table to compute the score
    real_page_table: torch.Tensor

    # NSA metadata (nsa prefill is expanded)
    nsa_cache_seqlens_int32: torch.Tensor  # these seqlens are clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend

    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None


@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
    attn_metadata: NSAMetadata

    def get_seqlens_int32(self) -> torch.Tensor:
        return self.attn_metadata.cache_seqlens_int32

    def get_page_table_64(self) -> torch.Tensor:
        return self.attn_metadata.real_page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
        return self.attn_metadata.nsa_seqlens_expanded

    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        from sgl_kernel import fast_topk_transform_fused, fast_topk_v2

        if not NSA_FUSE_TOPK:
            return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)

        # NOTE(dark): if fused, we return a transformed page table directly
        return fast_topk_transform_fused(
            score=logits,
            lengths=self.get_seqlens_expanded(),
            page_table_size_1=self.attn_metadata.page_table_1,
            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
            topk=topk,
        )


def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
    return torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
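
compute_cu_seqlens turns per-request lengths into the cumulative offsets that varlen attention kernels expect, e.g.:

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
print(compute_cu_seqlens(seqlens))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
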
_NSA_IMPL_T: TypeAlias = Literal[
    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]

NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T


class NativeSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata: NSAMetadata
        self.device = model_runner.device
        assert isinstance(model_runner.page_size, int)
        self.real_page_size = model_runner.page_size
        self.num_splits = (
            1 if model_runner.server_args.enable_deterministic_inference else 0
        )
        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
        self.nsa_kv_cache_store_fp8 = (
            model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
        )
        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
        self.max_context_len = model_runner.model_config.context_len
        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim

        assert model_runner.req_to_token_pool is not None
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode

        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)

        if _is_hip:
            max_bs = model_runner.req_to_token_pool.size

            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

    def get_device_int32_arange(self, l: int) -> torch.Tensor:
        if l > len(self._arange_buf):
            next_pow_of_2 = 1 << (l - 1).bit_length()
            self._arange_buf = torch.arange(
                next_pow_of_2, device=self.device, dtype=torch.int32
            )
        return self._arange_buf[:l]

    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
        page_size = self.real_page_size
        if page_size == 1:
            return page_table
        max_seqlen_k = page_table.shape[1]
        strided_indices = torch.arange(
            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
        )
        return page_table[:, strided_indices] // page_size
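
The conversion above samples one slot per page and divides by the page size. A worked example with page_size = 4 (hypothetical token slot indices):

# page_size-1 table for one request: its 8 token slots live in pages 5 and 9
page_table_1 = torch.tensor([[20, 21, 22, 23, 36, 37, 38, 39]], dtype=torch.int32)
strided = torch.arange(0, 8, 4)  # first slot of each page
print(page_table_1[:, strided] // 4)  # tensor([[5, 9]], dtype=torch.int32)
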
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        batch_size = forward_batch.batch_size
        device = forward_batch.seq_lens.device

        assert (
            forward_batch.spec_info is None
        ), "Spec decoding is not supported for NSA backend now"
        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
        assert forward_batch.seq_lens_cpu is not None
        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
        page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, :max_seqlen_k
        ]

        if forward_batch.forward_mode.is_decode_or_idle():
            extend_seq_lens_cpu = [1] * batch_size
            max_seqlen_q = 1
            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
            seqlens_expanded = cache_seqlens_int32
        elif forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.extend_seq_lens is not None
                and forward_batch.extend_prefix_lens_cpu is not None
            ), "All of them must not be None"
            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
            assert forward_batch.extend_seq_lens is not None
            if any(forward_batch.extend_prefix_lens_cpu):
                max_seqlen_q = max(extend_seq_lens_cpu)
                cu_seqlens_q = compute_cu_seqlens(
                    forward_batch.extend_seq_lens.to(torch.int32)
                )
            else:
                max_seqlen_q = max_seqlen_k
                cu_seqlens_q = cu_seqlens_k
            seqlens_expanded = torch.cat(
                [
                    torch.arange(
                        kv_len - qo_len + 1,
                        kv_len + 1,
                        dtype=torch.int32,
                        device=device,
                    )
                    for qo_len, kv_len in zip(
                        forward_batch.extend_seq_lens_cpu,
                        forward_batch.seq_lens_cpu.tolist(),
                        strict=True,
                    )
                ]
            )
        else:
            assert False, f"Unsupported {forward_batch.forward_mode = }"

        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            original_seq_lens=seqlens_expanded,
            nsa_index_topk=self.nsa_index_topk,
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_k=max_seqlen_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table,
            flashmla_metadata=(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO: handle MTP, where this is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=seqlens_expanded,
            nsa_extend_seq_lens_list=extend_seq_lens_cpu,
            real_page_table=self._transform_table_1_to_real(page_table),
        )

        self.forward_metadata = metadata

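    # A worked example of the metadata above (illustrative, not executed here):
    # for one extend request adding qo_len=3 new tokens on top of kv_len=5
    # total tokens, `seqlens_expanded` enumerates the causal kv length of each
    # query token:
    #
    #   torch.arange(5 - 3 + 1, 5 + 1)  # tensor([3, 4, 5])
    #
    # Assuming compute_nsa_seqlens clamps each entry at nsa_index_topk (say
    # topk=4), the NSA kv lengths become [3, 4, 4]: no token ever attends to
    # more than topk selected positions.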
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs.

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata: Dict = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            # fake page_table for sparse_prefill
            "page_table": torch.zeros(
                max_bs,
                self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            ),
            "flashmla_metadata": (
                self._compute_flashmla_metadata(
                    cache_seqlens=torch.ones(
                        max_bs, dtype=torch.int32, device=self.device
                    ),
                    seq_len_q=1,  # TODO: handle MTP, where this is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
    ):
        """Initialize forward metadata for capturing CUDA graph."""
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"

        # Normal Decode
        # Get sequence information
        cache_seqlens_int32 = seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)

        # Use max context length for seq_len_k
        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        max_seq_len_k = page_table_1.shape[1]

        # Precompute page table
        # Precompute cumulative sequence lengths

        # NOTE(dark): this is always arange, since we are decoding
        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
        real_page_table = self._transform_table_1_to_real(page_table_1)

        if NSA_DECODE_IMPL == "flashmla_decode":
            flashmla_metadata = self.decode_cuda_graph_metadata[
                "flashmla_metadata"
            ].slice(slice(0, bs + 1))
            flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO: handle MTP, where this is not 1
                )
            )
        else:
            flashmla_metadata = None

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=1,
            max_seq_len_k=max_seq_len_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table_1,
            flashmla_metadata=flashmla_metadata,
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=cache_seqlens_int32,
            real_page_table=real_page_table,
            nsa_extend_seq_lens_list=[1] * bs,
        )
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: Optional[torch.Tensor] = None,
    ):
        """Initialize forward metadata for replaying CUDA graph."""
        assert seq_lens_cpu is not None
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]

        # Normal Decode
        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
        max_len = int(seq_lens_cpu.max().item())

        cache_seqlens = seq_lens.to(torch.int32)
        metadata.cache_seqlens_int32.copy_(cache_seqlens)
        metadata.cu_seqlens_k[1:].copy_(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
        )
        page_indices = self.req_to_token[req_pool_indices, :max_len]
        metadata.page_table_1[:, :max_len].copy_(page_indices)
        assert (
            metadata.nsa_cache_seqlens_int32 is not None
            and metadata.nsa_cu_seqlens_k is not None
            and self.nsa_index_topk is not None
        )
        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
        metadata.nsa_cu_seqlens_k[1:].copy_(
            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
        )
        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy

        assert self.real_page_size == metadata.page_size
        if self.real_page_size > 1:
            real_table = self._transform_table_1_to_real(page_indices)
            new_len = real_table.shape[1]
            metadata.real_page_table[:, :new_len].copy_(real_table)
        else:
            assert metadata.real_page_table is metadata.page_table_1

        if NSA_DECODE_IMPL == "flashmla_decode":
            metadata.flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens,
                    seq_len_q=1,  # TODO: handle MTP, where this is not 1
                )
            )

        self.forward_metadata = metadata

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert (
            not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
        ), "NSA backend doesn't support speculative decoding"
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                    layer,
                    cache_loc,
                    k,
                    k_rope,
                )

        metadata = self.forward_metadata
        causal = not layer.is_cross_attention
        assert causal, "NSA is causal only"

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}

        # Do absorbed multi-latent attention
        assert q_rope is not None
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        # When the cache is both stored and computed in fp8, no dtype conversion is needed.
        if not (
            NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
        ):
            kv_cache = kv_cache.to(q.dtype)

        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

        # NOTE(dark): here, we use page size = 1

        if NSA_FUSE_TOPK:
            page_table_1 = topk_indices
        else:
            assert metadata.nsa_extend_seq_lens_list is not None
            page_table_1 = transform_index_page_table_prefill(
                page_table=metadata.page_table_1,
                topk_indices=topk_indices,
                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
                page_size=1,
            )
        if NSA_PREFILL_IMPL == "tilelang":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_tilelang(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_PREFILL_IMPL == "flashmla_prefill":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_prefill(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_PREFILL_IMPL == "flashmla_decode":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_decode(
                q_all=q_all,
                kv_cache=kv_cache,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
                # TODO: optimize args
                layer=layer,
                metadata=metadata,
                page_table_1=page_table_1,
            )
        elif NSA_PREFILL_IMPL == "fa3":
            return self._forward_fa3(
                q_rope=q_rope,
                kv_cache=kv_cache,
                v_head_dim=layer.v_head_dim,
                q_nope=q_nope,
                page_table=page_table_1,
                cache_seqlens=metadata.nsa_cache_seqlens_int32,
                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
                max_seqlen_q=metadata.nsa_max_seqlen_q,
                sm_scale=layer.scaling,
                logit_cap=layer.logit_cap,
                page_size=1,
            )
        else:
            raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                    layer,
                    cache_loc,
                    k,
                    k_rope,
                )

        metadata = self.forward_metadata
        causal = not layer.is_cross_attention
        assert causal, "NSA is causal only"

        # Do absorbed multi-latent attention
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

        if NSA_FUSE_TOPK:
            page_table_1 = topk_indices
        else:
            page_table_1 = transform_index_page_table_decode(
                page_table=metadata.page_table_1,
                topk_indices=topk_indices,
                page_size=1,
            )

        if NSA_DECODE_IMPL == "flashmla_prefill":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_prefill(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "flashmla_decode":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_decode(
                q_all=q_all,
                kv_cache=kv_cache,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
                # TODO: optimize args
                layer=layer,
                metadata=metadata,
                page_table_1=page_table_1,
            )
        elif NSA_DECODE_IMPL == "tilelang":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_tilelang(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "fa3":
            return self._forward_fa3(
                q_rope=q_rope,
                kv_cache=kv_cache,
                v_head_dim=layer.v_head_dim,
                q_nope=q_nope,
                page_table=page_table_1,
                cache_seqlens=metadata.nsa_cache_seqlens_int32,
                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
                max_seqlen_q=metadata.nsa_max_seqlen_q,
                sm_scale=layer.scaling,
                logit_cap=layer.logit_cap,
                page_size=1,
            )
        elif NSA_DECODE_IMPL == "aiter":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_aiter(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                layer=layer,
                metadata=metadata,
                bs=forward_batch.batch_size,
            )
        else:
            assert False, f"Unsupported {NSA_DECODE_IMPL = }"

    def _forward_fa3(
        self,
        q_rope: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        q_nope: torch.Tensor,
        page_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        sm_scale: float,
        logit_cap: float,
        page_size: int,
    ) -> torch.Tensor:
        k_rope_cache = kv_cache[:, :, v_head_dim:]
        c_kv_cache = kv_cache[:, :, :v_head_dim]
        qk_rope_dim = k_rope_cache.shape[-1]
        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
        o = flash_attn_with_kvcache(
            q=q_rope,
            k_cache=k_rope_cache,
            v_cache=c_kv_cache,
            qv=q_nope,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=sm_scale,
            causal=True,
            softcap=logit_cap,
            return_softmax_lse=False,
            num_splits=self.num_splits,
        )
        return o  # type: ignore

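    # Shape sketch for the split above (the concrete dims are DeepSeek's MLA
    # defaults, assumed here for illustration): with kv_lora_rank=512 and
    # qk_rope_head_dim=64, kv_cache is (num_tokens, 1, 576); slicing at
    # v_head_dim yields c_kv_cache (num_tokens, 1, 512) and k_rope_cache
    # (num_tokens, 1, 64), each then viewed as (num_pages, page_size, 1, dim)
    # with page_size=1 so the attention kernel can index them per token.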
    def _forward_flashmla_prefill(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        page_table_1: torch.Tensor,
        sm_scale: float,
    ) -> torch.Tensor:
        from flash_mla import flash_mla_sparse_fwd

        o, _, _ = flash_mla_sparse_fwd(
            q=q_all,
            kv=kv_cache,
            indices=page_table_1.unsqueeze(1),
            sm_scale=sm_scale,
            d_v=v_head_dim,
        )
        return o

    def _forward_flashmla_decode(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        sm_scale: float,
        layer,
        metadata: NSAMetadata,
        page_table_1,
    ) -> torch.Tensor:
        from flash_mla import flash_mla_with_kvcache

        cache_seqlens = metadata.nsa_cache_seqlens_int32

        # TODO: the 2nd dim is seq_len_q; it needs to be >1 for MTP
        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
        assert self.real_page_size == 64, "only page size 64 is supported"

        if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
            # inefficiently quantize the whole cache
            kv_cache = quantize_k_cache(kv_cache)

        indices = page_table_1.unsqueeze(1)
        assert (
            indices.shape[-1] == self.nsa_index_topk
        )  # requirement of the FlashMLA decode kernel

        o, _ = flash_mla_with_kvcache(
            q=q_all,
            k_cache=kv_cache,
            cache_seqlens=cache_seqlens,
            head_dim_v=v_head_dim,
            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
            num_splits=metadata.flashmla_metadata.num_splits,
            softmax_scale=sm_scale,
            indices=indices,
            # The docs say block_table is unused, but passing None raises an error.
            block_table=torch.empty(
                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
            ),
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
        )
        return o

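    # Shape note (inferred from the asserts above, not stated in the kernel
    # docs): `indices` is (batch, seq_len_q=1, topk), i.e. every query selects
    # exactly nsa_index_topk kv slots; presumably slots beyond the true kv
    # length are padded with -1, matching the -1 filtering in _forward_aiter.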
    def _forward_tilelang(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        page_table_1: torch.Tensor,
        sm_scale: float,
    ) -> torch.Tensor:
        from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd

        return tilelang_sparse_fwd(
            q=q_all,
            kv=kv_cache,
            indices=page_table_1.unsqueeze(1),
            sm_scale=sm_scale,
            d_v=v_head_dim,
        )

    def _forward_aiter(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        page_table_1: torch.Tensor,
        layer: RadixAttention,
        metadata: NSAMetadata,
        bs: int,
    ) -> torch.Tensor:
        q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)

        if layer.head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        kv_indptr = self.kv_indptr

        non_minus1_mask = page_table_1 != -1
        non_minus1_counts = non_minus1_mask.sum(dim=1)
        kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)

        kv_indices = page_table_1[page_table_1 != -1]

        mla_decode_fwd(
            q.view(-1, layer.tp_q_head_num, layer.head_dim),
            kv_cache.view(-1, 1, 1, layer.head_dim),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            metadata.cu_seqlens_q,
            kv_indptr,
            kv_indices,
            metadata.cu_seqlens_q,
            metadata.max_seq_len_q,
            layer.scaling,
            layer.logit_cap,
        )
        # kv_cache = kv_cache.view(-1, 1, layer.head_dim)
        return o

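    # A small illustration (hypothetical values) of the ragged layout built
    # above: with page_table_1 = [[7, 3, -1], [5, -1, -1]], the per-row valid
    # counts are [2, 1], so kv_indptr becomes [0, 2, 3] and kv_indices
    # flattens the valid entries to [7, 3, 5]. Request i of the batch then
    # owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].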
    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 1

    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> NSAIndexerMetadata:
        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)

    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
        from flash_mla import get_mla_metadata

        flashmla_metadata, num_splits = get_mla_metadata(
            cache_seqlens=cache_seqlens,
            # TODO: the docs say `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
            # but the parameter name suggests it expects seq_len_q?
            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
            num_heads_k=1,
            num_heads_q=self.num_q_heads,
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
            topk=self.nsa_index_topk,
        )

        return NSAFlashMLAMetadata(
            flashmla_metadata=flashmla_metadata,
            num_splits=num_splits,
        )

@@ -813,45 +813,69 @@ class DeepEPMoE(EPMoE):
        if isinstance(hidden_states, tuple):
            per_token_scale = hidden_states[1]
            hidden_states = hidden_states[0]
        else:
            # dynamic quant
            hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
                hidden_states
            )

        group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
            hidden_states.device
        )
        if self.w13_weight.dtype != torch.int8:
            # gmm1: gate_up_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w13_weight.permute(0, 2, 1)],
                # per_token_scale=[per_token_scale],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]
            hidden_states = torch_npu.npu_swiglu(hidden_states)
            # gmm2: down_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w2_weight.permute(0, 2, 1)],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]
        else:
            if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
                    hidden_states
                )
            # gmm1: gate_up_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w13_weight],
                scale=[self.w13_weight_scale.to(output_dtype)],
                per_token_scale=[per_token_scale],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]
            # act_fn: swiglu
            hidden_states = torch_npu.npu_swiglu(hidden_states)
            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
                hidden_states
            )
            # gmm2: down_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w2_weight],
                scale=[self.w2_weight_scale.to(output_dtype)],
                per_token_scale=[swiglu_out_scale],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]

        return hidden_states

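# A pure-PyTorch sketch (illustrative only, not NPU code) of the grouped
# matmul contract used by gmm1/gmm2 above, assuming `group_list` carries
# cumulative row boundaries (one of the layouts npu_grouped_matmul accepts):
# each expert's weight multiplies only its contiguous slice of tokens.
#
#   import torch
#   x = torch.randn(6, 4)              # 6 tokens, hidden 4
#   w = torch.randn(2, 4, 8)           # 2 experts
#   group_list = torch.tensor([2, 6])  # expert 0: rows 0:2, expert 1: rows 2:6
#   outs, start = [], 0
#   for i, end in enumerate(group_list.tolist()):
#       outs.append(x[start:end] @ w[i])
#       start = end
#   y = torch.cat(outs)                # (6, 8), same result contract as gmm1
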
@@ -860,47 +884,72 @@ class DeepEPMoE(EPMoE):
        assert isinstance(dispatch_output, DeepEPLLOutput)
        hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output

        if isinstance(hidden_states, tuple):
            per_token_scale = hidden_states[1]
            hidden_states = hidden_states[0]

        group_list = group_list.to(torch.int64)

        if self.w13_weight.dtype != torch.int8:
            # gmm1: gate_up_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w13_weight.permute(0, 2, 1)],
                # per_token_scale=[per_token_scale],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]
            hidden_states = torch_npu.npu_swiglu(hidden_states)
            # gmm2: down_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w2_weight.permute(0, 2, 1)],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]
        else:
            # gmm1: gate_up_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w13_weight],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=torch.int32,
            )[0]

            # act_fn: swiglu
            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                x=hidden_states,
                weight_scale=self.w13_weight_scale.to(torch.float32),
                activation_scale=per_token_scale,
                bias=None,
                quant_scale=None,
                quant_offset=None,
                group_index=group_list,
                activate_left=True,
                quant_mode=1,
            )

            # gmm2: down_proj
            hidden_states = torch_npu.npu_grouped_matmul(
                x=[hidden_states],
                weight=[self.w2_weight],
                scale=[self.w2_weight_scale.to(output_dtype)],
                per_token_scale=[swiglu_out_scale],
                split_item=2,
                group_list_type=group_list_type,
                group_type=0,
                group_list=group_list,
                output_dtype=output_dtype,
            )[0]

        return hidden_states

@@ -112,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_custom_logit_processor",
    "disaggregation_mode",
    "enable_deterministic_inference",
    "nsa_prefill",
    "nsa_decode",
]

# Put some global args for easy access

@@ -76,35 +76,49 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
            (last_loc + 1) % self.page_size == prefix_lens % self.page_size
        )

        num_new_pages = (
            (seq_lens + self.page_size - 1) // self.page_size
            - (prefix_lens + self.page_size - 1) // self.page_size
        ).sum()
        num_new_pages_item = num_new_pages.item()
        if self.need_sort and num_new_pages_item > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages_item > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )

        if num_new_pages_item < 200:
            import sgl_kernel_npu

            torch.ops.npu.alloc_extend(
                prefix_lens,
                seq_lens,
                last_loc,
                self.free_pages,
                self.page_size,
                out_indices,
                num_new_pages,
            )

        else:
            alloc_extend_kernel_ascend(
                prefix_lens,
                seq_lens,
                last_loc,
                self.free_pages,
                out_indices,
                self.page_size,
                self.device,
            )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages_item:]
        return out_indices

    def alloc_decode(

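# A quick numeric check (illustrative) of the page math above: with
# page_size=128, seq_len=300 and prefix_len=130, the request needs
# ceil(300/128) = 3 pages in total, already holds ceil(130/128) = 2 via its
# prefix, and so allocates 3 - 2 = 1 new page; summing this difference over
# the batch gives num_new_pages.
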
@@ -15,6 +15,8 @@ limitations under the License.

from __future__ import annotations

from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

"""

@@ -1030,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
        use_nsa: bool = False,
        override_kv_cache_dim: Optional[int] = None,
    ):
        super().__init__(
            size,

@@ -1044,6 +1048,14 @@ class MLATokenToKVPool(KVCache):

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.use_nsa = use_nsa
        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
        # TODO: do not hardcode
        self.kv_cache_dim = (
            656
            if self.use_nsa and self.nsa_kv_cache_store_fp8
            else (kv_lora_rank + qk_rope_head_dim)
        )

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(

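# The hardcoded 656 above is consistent with DeepSeek's MLA dims
# (kv_lora_rank=512, qk_rope_head_dim=64) under an fp8-with-scales layout —
# a plausible breakdown, not something this diff states explicitly:
#
#   512 * 1             # fp8 compressed-kv bytes
#   + (512 // 128) * 4  # one fp32 scale per 128-element tile = 16 bytes
#   + 64 * 2            # rope part kept at 2 bytes/elem = 128 bytes
#   # total: 656 bytes per token
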
@@ -1067,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.kv_buffer = [
            torch.zeros(
                (size + page_size, 1, self.kv_cache_dim),
                dtype=self.store_dtype,
                device=device,
            )

@@ -1130,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:

@@ -1147,16 +1160,28 @@ class MLATokenToKVPool(KVCache):
        cache_k_rope: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if self.use_nsa and self.nsa_kv_cache_store_fp8:
            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
            # TODO: no need to cat
            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
            cache_k = cache_k.view(self.store_dtype)
            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
        else:
            if cache_k_nope.dtype != self.dtype:
                cache_k_nope = cache_k_nope.to(self.dtype)
                cache_k_rope = cache_k_rope.to(self.dtype)
            if self.store_dtype != self.dtype:
                cache_k_nope = cache_k_nope.view(self.store_dtype)
                cache_k_rope = cache_k_rope.view(self.store_dtype)

            set_mla_kv_buffer_triton(
                self.kv_buffer[layer_id - self.start_layer],
                loc,
                cache_k_nope,
                cache_k_rope,
            )

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()

@@ -1186,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
        torch.cuda.synchronize()


class NSATokenToKVPool(MLATokenToKVPool):
    def __init__(
        self,
        size: int,
        page_size: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        index_head_dim: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            kv_lora_rank,
            qk_rope_head_dim,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
            use_nsa=True,
        )
        # self.index_k_dtype = torch.float8_e4m3fn
        # self.index_k_scale_dtype = torch.float32
        self.index_head_dim = index_head_dim
        # num head == 1 and head dim == 128 for index_k in NSA
        assert index_head_dim == 128

        self.quant_block_size = 128

        assert self.page_size == 64
        self.index_k_with_scale_buffer = [
            torch.zeros(
                # Layout:
                # ref: test_attention.py :: kv_cache_cast_to_fp8
                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
                # data: for page i,
                # * buf[i, :page_size * head_dim] for fp8 data
                # * buf[i, page_size * head_dim:].view(float32) for scale
                (
                    (size + page_size + 1) // self.page_size,
                    self.page_size
                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
                ),
                dtype=torch.uint8,
                device=device,
            )
            for _ in range(layer_num)
        ]

    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self.index_k_with_scale_buffer[layer_id - self.start_layer]

    def get_index_k_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetK.execute(
            self, buf, seq_len=seq_len, page_indices=page_indices
        )

    def get_index_k_scale_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetS.execute(
            self, buf, seq_len=seq_len, page_indices=page_indices
        )

    # TODO: rename later (currently uses a different name to avoid confusion)
    def set_index_k_and_scale_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
        index_k_scale: torch.Tensor,
    ) -> None:
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        index_buf_accessor.SetKAndS.execute(
            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
        )

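# Sanity arithmetic (illustrative) for the index_k page layout above: with
# page_size=64, index_head_dim=128 and quant_block_size=128, each page holds
# 64 * (128 + (128 // 128) * 4) = 64 * 132 = 8448 bytes — 8192 bytes of fp8
# index_k data followed by one fp32 scale per token (64 * 4 = 256 bytes).

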
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    def __init__(
        self,

@@ -1194,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        index_head_dim: Optional[int],
        layer_num: int,
        device: str,
        enable_memory_saver: bool,

@@ -1213,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.index_head_dim = index_head_dim

        self.custom_mem_pool = None

@@ -1240,6 +1364,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            dtype=self.store_dtype,
            device=self.device,
        )
        if self.index_head_dim is not None:
            self.index_k_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    1,
                    self.index_head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )

        self._finalize_allocation_log(size)

@@ -1251,6 +1387,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            kv_size_bytes += get_tensor_size_bytes(k_cache)
        for v_cache in self.v_buffer:
            kv_size_bytes += get_tensor_size_bytes(v_cache)
        if self.index_head_dim is not None:
            assert hasattr(self, "index_k_buffer")
            for index_k_cache in self.index_k_buffer:
                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
        return kv_size_bytes

    def get_kv_buffer(self, layer_id: int):

@@ -1277,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    def get_index_k_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.index_k_buffer[layer_id - self.start_layer]

    # for disagg
    def get_contiguous_buf_infos(self):
        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.

@@ -1289,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
        ]
        if self.index_head_dim is not None:
            kv_data_ptrs += [
                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
            ]
            kv_data_lens += [
                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
            ]
            kv_item_lens += [
                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
            ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(

@@ -1325,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )

    def set_index_k_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        if index_k.dtype != self.dtype:
            index_k = index_k.to(self.dtype)

        if self.store_dtype != self.dtype:
            index_k = index_k.view(self.store_dtype)

        torch_npu.npu_scatter_nd_update_(
            self.index_k_buffer[layer_id - self.start_layer].view(
                -1, 1, self.index_head_dim
            ),
            loc.view(-1, 1),
            index_k.view(-1, 1, self.index_head_dim),
        )


class DoubleSparseTokenToKVPool(KVCache):
    def __init__(

@@ -522,6 +522,7 @@ class CudaGraphRunner:
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        seq_lens_cpu = self.seq_lens_cpu[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:

@@ -592,6 +593,7 @@ class CudaGraphRunner:
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,

@@ -293,6 +293,7 @@ class ForwardBatch:
    # For padding
    padded_static_len: int = -1  # -1 if not padded
    num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor
    num_token_non_padded_cpu: int = None

    # For Qwen2-VL
    mrope_positions: torch.Tensor = None

@@ -354,6 +355,7 @@ class ForwardBatch:
        ret.num_token_non_padded = torch.tensor(
            len(batch.input_ids), dtype=torch.int32
        ).to(device, non_blocking=True)
        ret.num_token_non_padded_cpu = len(batch.input_ids)

        # For MLP sync
        if batch.global_num_tokens is not None:

@@ -31,7 +31,12 @@ import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
    AttentionArch,
    ModelConfig,
    get_nsa_index_head_dim,
    is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (

@@ -96,6 +101,7 @@ from sglang.srt.mem_cache.memory_pool import (
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)

@@ -157,6 +163,7 @@ MLA_ATTENTION_BACKENDS = [
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
]

@@ -1547,6 +1554,7 @@ class ModelRunner:
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(

@@ -1555,6 +1563,7 @@ class ModelRunner:
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    index_head_dim=self.model_config.index_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,

@@ -1574,7 +1583,22 @@ class ModelRunner:
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        elif self.use_mla_backend and is_nsa_model:
            self.token_to_kv_pool = NSATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
            )
        elif self.use_mla_backend:
            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,

@@ -75,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
        self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        if self.model_runner.model_config.index_head_dim is None:
            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
                self.bs - self.raw_bs
            )
            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
            thread.start()
            self.graphs[self.bs].replay()
            thread.join()
        else:
            self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):

@@ -15,6 +15,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
from __future__ import annotations

import concurrent.futures
import logging

@@ -25,10 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt import single_batch_overlap
from sglang.srt.configs.model_config import (
    get_nsa_index_head_dim,
    get_nsa_index_n_heads,
    get_nsa_index_topk,
    is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,

@@ -48,6 +55,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
    NPUFusedMLAPreprocess,
    is_mla_preprocess_enabled,
)
from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,

@@ -172,10 +180,13 @@ elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
elif _is_npu:
    import custom_ops
    import sgl_kernel_npu
    import torch_npu
else:
    pass


_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()

@@ -184,6 +195,7 @@ logger = logging.getLogger(__name__)

FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "nsa",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",

@@ -204,6 +216,9 @@ class AttnForwardMethod(IntEnum):
    # Use absorbed multi-latent attention
    MLA = auto()

    # Use Deepseek V3.2 sparse multi-latent attention
    NPU_MLA_SPARSE = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

@@ -246,9 +261,15 @@ def handle_attention_ascend(attn, forward_batch):
        and not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
    ):
        if hasattr(attn, "indexer"):
            return AttnForwardMethod.NPU_MLA_SPARSE
        else:
            return AttnForwardMethod.MHA
    else:
        if hasattr(attn, "indexer"):
            return AttnForwardMethod.NPU_MLA_SPARSE
        else:
            return AttnForwardMethod.MLA


def _get_sum_extend_prefix_lens(forward_batch):

@@ -267,7 +288,9 @@ def _is_extend_without_speculative(forward_batch):
    )


def _handle_attention_backend(
    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
):
    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
    disable_ragged = (
        backend_name in ["flashinfer", "flashmla"]

@@ -333,6 +356,10 @@ def handle_attention_aiter(attn, forward_batch):
    return AttnForwardMethod.MLA


def handle_attention_nsa(attn, forward_batch):
    return AttnForwardMethod.MLA


def handle_attention_triton(attn, forward_batch):
    if (
        _is_extend_without_speculative(forward_batch)

@@ -1005,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # NOTE: modifications to rope_scaling must happen early enough, since e.g. the Indexer needs it
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(

@@ -1042,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )

        self.use_nsa = is_deepseek_nsa(config)
        if self.use_nsa:
            self.indexer = Indexer(
                hidden_size=hidden_size,
                index_n_heads=get_nsa_index_n_heads(config),
                index_head_dim=get_nsa_index_head_dim(config),
                rope_head_dim=qk_rope_head_dim,
                index_topk=get_nsa_index_topk(config),
                q_lora_rank=q_lora_rank,
                max_position_embeddings=max_position_embeddings,
                rope_theta=rope_theta,
                scale_fmt="ue8m0",
                block_size=128,
                rope_scaling=rope_scaling,
                prefix=add_prefix("indexer", prefix),
                quant_config=quant_config,
                layer_id=layer_id,
                alt_stream=alt_stream,
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),

@@ -1064,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,

@@ -1193,8 +1241,8 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
        if self.is_mla_preprocess_enabled:
            assert (
                quant_config is None or quant_config.get_name() == "w8a8_int8"
            ), "MLA Preprocess only works with Unquant or W8A8Int8"
            self.mla_preprocess = None

    def dispatch_attn_forward_method(

@@ -1272,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )

@@ -1304,6 +1351,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            inner_state = self.mla_preprocess.forward(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
            inner_state = self.forward_npu_sparse_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator

@@ -1329,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
            return self.forward_npu_sparse_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:

@@ -1424,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        q_lora = None
        if self.q_lora_rank is not None:
            if (
                (not isinstance(hidden_states, tuple))

@@ -1462,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            # q_lora needed by indexer
            if self.use_nsa:
                q_lora = q

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:

@@ -1527,14 +1585,41 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
|
||||
not _use_aiter or not _is_gfx95_supported
|
||||
not _use_aiter or not _is_gfx95_supported or self.use_nsa
|
||||
):
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
topk_indices = None
|
||||
if q_lora is not None:
|
||||
topk_indices = self.indexer(
|
||||
x=hidden_states,
|
||||
q_lora=q_lora,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
layer_id=self.layer_id,
|
||||
)
|
||||
|
||||
return (
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
topk_indices,
|
||||
)
|
||||
|
    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
        self,
        q_pe,
        k_pe,
        q_nope_out,
        k_nope,
        forward_batch,
        zero_allocator,
        positions,
        topk_indices,
    ):
        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            extra_args = {}
@@ -1543,6 +1628,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }

            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
@@ -1551,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
            )
        else:
            if _use_aiter_gfx95:
@@ -1570,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                q = torch.cat([q_nope_out, q_pe], dim=-1)
                k = torch.cat([k_nope, k_pe], dim=-1)

            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
            attn_output = self.attn_mqa(
                q,
                k,
                k_nope,
                forward_batch,
                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
            )
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.use_deep_gemm_bmm:
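The `**(dict(topk_indices=topk_indices) if topk_indices is not None else {})` spread passes the new argument only when NSA produced indices, so attention backends whose signatures predate topk_indices keep working unchanged. The idiom in isolation (hypothetical backend functions, not the real attn_mqa):

def dense_backend(q, k):
    return "dense"


def nsa_backend(q, k, topk_indices=None):
    return "sparse" if topk_indices is not None else "dense"


def run(backend, q, k, topk_indices=None):
    # Only forward the kwarg when it carries a value.
    extra = dict(topk_indices=topk_indices) if topk_indices is not None else {}
    return backend(q, k, **extra)


print(run(dense_backend, 0, 0))                  # dense
print(run(nsa_backend, 0, 0, topk_indices=[3]))  # sparse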
@@ -1652,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):

        return output

    def forward_npu_sparse_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """
        Reuse the `self.q_lora_rank is not None` branch from forward_absorb_prepare.
        """
        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
            if self.mla_preprocess is None:
                self.mla_preprocess = NPUFusedMLAPreprocess(
                    self.fused_qkv_a_proj_with_mqa,
                    self.q_a_layernorm,
                    self.kv_a_layernorm,
                    self.q_b_proj,
                    self.w_kc,
                    self.rotary_emb,
                    self.layer_id,
                    self.num_local_heads,
                    self.qk_nope_head_dim,
                    self.qk_rope_head_dim,
                )
            (
                q_pe,
                k_pe,
                q_nope_out,
                k_nope,
                forward_batch,
                zero_allocator,
                positions,
            ) = self.mla_preprocess.forward(
                positions, hidden_states, forward_batch, zero_allocator
            )

            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, _ = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q_lora = self.q_a_layernorm(q)
        else:
            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

            if (
                (not isinstance(hidden_states, tuple))
                and hidden_states.shape[0] <= 16
                and self.use_min_latency_fused_a_gemm
            ):
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            if self.alt_stream is not None and get_is_capture_mode():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
                    q, k_nope = fused_rms_mxfp4_quant(
                        q,
                        self.q_a_layernorm.weight,
                        self.q_a_layernorm.variance_epsilon,
                        k_nope,
                        self.kv_a_layernorm.weight,
                        self.kv_a_layernorm.variance_epsilon,
                    )
                else:
                    q = self.q_a_layernorm(q)
                    k_nope = self.kv_a_layernorm(k_nope)

            q_lora = q.clone()  # required for topk_indices
            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)

            q_nope, q_pe = q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )
            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

            if self.use_deep_gemm_bmm:
                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                    per_token_group_quant_mla_deep_gemm_masked_fp8(
                        q_nope.transpose(0, 1)
                    )
                )
                q_nope_out = q_nope.new_empty(
                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
                )
                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                    (q_nope_val, q_nope_scale),
                    (self.w_kc, self.w_scale_k),
                    q_nope_out,
                    masked_m,
                    expected_m,
                )
                q_nope_out = q_nope_out[:, :expected_m, :]
            elif _is_hip:
                # TODO(haishaw): add bmm_fp8 to ROCm
                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
                    x = q_nope.transpose(0, 1)
                    q_nope_out = torch.empty(
                        x.shape[0],
                        x.shape[1],
                        self.w_kc.shape[2],
                        device=x.device,
                        dtype=torch.bfloat16,
                    )
                    batched_gemm_afp4wfp4_pre_quant(
                        x,
                        self.w_kc.transpose(-2, -1),
                        self.w_scale_k.transpose(-2, -1),
                        torch.bfloat16,
                        q_nope_out,
                    )
                else:
                    q_nope_out = torch.bmm(
                        q_nope.to(torch.bfloat16).transpose(0, 1),
                        self.w_kc.to(torch.bfloat16) * self.w_scale,
                    )
            elif self.w_kc.dtype == torch.float8_e4m3fn:
                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                    q_nope.transpose(0, 1),
                    zero_allocator.allocate(1),
                )
                q_nope_out = bmm_fp8(
                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
                )
            else:
                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

            q_nope_out = q_nope_out.transpose(0, 1)

            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
                not _use_aiter or not _is_gfx95_supported
            ):
                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        # TODO: multi-stream indexer
        topk_indices = self.indexer(
            hidden_states, q_lora, positions, forward_batch, self.layer_id
        )

        return (
            q_pe,
            k_pe,
            q_nope_out,
            k_nope,
            topk_indices,
            forward_batch,
            zero_allocator,
            positions,
        )
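The "overlap qk norm" branch above hides the latency of one normalization behind the other by issuing them on separate CUDA streams during graph capture. A simplified sketch of the stream handshake (layer_norm stands in for the RMSNorm layers; the real code additionally guards on capture mode):

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    alt_stream = torch.cuda.Stream()
    q = torch.randn(16, 1536, device="cuda")
    k_nope = torch.randn(16, 512, device="cuda")

    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)       # alt stream must see q/k ready
    q = F.layer_norm(q, (q.shape[-1],))          # runs on the current stream
    with torch.cuda.stream(alt_stream):
        k_nope = F.layer_norm(k_nope, (k_nope.shape[-1],))  # overlaps with q
    current_stream.wait_stream(alt_stream)       # rejoin before consumers read k_nope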
    def forward_npu_sparse_core(
        self,
        q_pe,
        k_pe,
        q_nope_out,
        k_nope,
        topk_indices,
        forward_batch,
        zero_allocator,
        positions,
    ):
        attn_output = self.attn_mqa(
            q_nope_out.contiguous(),
            k_nope.contiguous(),
            k_nope.contiguous(),
            forward_batch,
            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
            q_rope=q_pe.contiguous(),
            k_rope=k_pe.contiguous(),
            topk_indices=topk_indices,
        )
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        attn_bmm_output = torch.empty(
            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
            dtype=attn_output.dtype,
            device=attn_output.device,
        )

        if not forward_batch.forward_mode.is_decode():
            attn_output = attn_output.transpose(0, 1)
            torch.bmm(
                attn_output,
                self.w_vc,
                out=attn_bmm_output.view(
                    -1, self.num_local_heads, self.v_head_dim
                ).transpose(0, 1),
            )
        else:
            attn_output = attn_output.contiguous()
            torch.ops.npu.batch_matmul_transpose(
                attn_output, self.w_vc, attn_bmm_output
            )

        attn_bmm_output = attn_bmm_output.reshape(
            -1, self.num_local_heads * self.v_head_dim
        )

        output, _ = self.o_proj(attn_bmm_output)
        return output
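forward_npu_sparse_core attends in the compressed kv_lora_rank space and only afterwards lifts the result back to v_head_dim through a per-head batched matmul with w_vc. A shape walk-through with illustrative sizes (not the model's real dimensions):

import torch

tokens, heads, kv_lora_rank, v_head_dim = 8, 4, 512, 128
attn_output = torch.randn(tokens, heads, kv_lora_rank)
w_vc = torch.randn(heads, kv_lora_rank, v_head_dim)

out = torch.bmm(attn_output.transpose(0, 1), w_vc)  # (heads, tokens, v_head_dim)
out = out.transpose(0, 1).reshape(tokens, heads * v_head_dim)
print(out.shape)  # torch.Size([8, 512]) -> fed to o_proj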
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,

@@ -2134,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:

        quant_format = (
            "mxfp4"
            if _is_gfx95_supported

@@ -3099,6 +3406,7 @@ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
AttentionBackendRegistry.register("nsa", handle_attention_nsa)
AttentionBackendRegistry.register("triton", handle_attention_triton)

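Registering "nsa" hooks the new backend into the same name-to-handler table the other attention backends use. The pattern, reduced to a self-contained toy (assumed shape of the registry; not SGLang's actual class):

class BackendRegistry:
    _handlers = {}

    @classmethod
    def register(cls, name, handler):
        cls._handlers[name] = handler

    @classmethod
    def get(cls, name):
        return cls._handlers[name]


BackendRegistry.register("nsa", lambda: "handle_attention_nsa")
print(BackendRegistry.get("nsa")())  # handle_attention_nsa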
@@ -3106,4 +3414,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass


EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
    pass


EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]

@@ -91,6 +91,7 @@ ATTENTION_BACKEND_CHOICES = [
    "triton",
    "torch_native",
    "flex_attention",
    "nsa",
    # NVIDIA specific
    "cutlass_mla",
    "fa3",
@@ -116,6 +117,8 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]

DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]

NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]

RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]


@@ -284,6 +287,8 @@ class ServerArgs:
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None
    nsa_prefill: str = "flashmla_prefill"
    nsa_decode: str = "fa3"

    # Speculative decoding
    speculative_algorithm: Optional[str] = None
@@ -719,6 +724,8 @@ class ServerArgs:
            self.sampling_backend = "pytorch"

    def _handle_model_specific_adjustments(self):
        from sglang.srt.configs.model_config import is_deepseek_nsa

        if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
            return

@@ -796,6 +803,48 @@ class ServerArgs:
            )
            self.disable_hybrid_swa_memory = True

        if is_deepseek_nsa(hf_config):
            if (
                self.attention_backend is None
                and self.prefill_attention_backend is None
                and self.decode_attention_backend is None
            ):
                self.attention_backend = "nsa"
                logger.warning("Set nsa attention backend for DeepSeek NSA.")

            if not is_npu():
                self.enable_dp_attention = True
                self.dp_size = self.tp_size
                logger.warning("DP attention is enabled for DeepSeek NSA.")

            self.page_size = 64
            logger.warning("Setting page size to 64 for DeepSeek NSA.")

            self.mem_fraction_static = 0.8
            logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")

            # For Hopper, we support both bf16 and fp8 KV cache; for Blackwell, we currently support fp8 only
            import torch

            major, _ = torch.cuda.get_device_capability()
            if major >= 10:
                self.kv_cache_dtype = "fp8_e4m3"
                logger.warning("Setting KV cache dtype to fp8.")

            if self.kv_cache_dtype == "fp8_e4m3":
                self.nsa_prefill = "flashmla_decode"
                self.nsa_decode = "flashmla_decode"
                logger.warning(
                    "Setting NSA backend to flashmla_decode for FP8 KV Cache."
                )

            # Logging env vars for NSA
            from sglang.srt.layers.attention.nsa.utils import (
                print_nsa_bool_env_vars,
            )

            print_nsa_bool_env_vars()
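The adjustments above amount to: default the backend to "nsa", enable DP attention off-NPU, force page size 64, lower mem_fraction_static to 0.8, and on SM 10.0+ (Blackwell) force an fp8 KV cache, which in turn pins both NSA phases to flashmla_decode. A condensed restatement as a pure function (a hypothetical helper for illustration, not the ServerArgs method itself):

def nsa_defaults(sm_major: int, kv_cache_dtype: str = "auto") -> dict:
    cfg = {
        "attention_backend": "nsa",
        "page_size": 64,
        "mem_fraction_static": 0.8,
        "kv_cache_dtype": kv_cache_dtype,
        "nsa_prefill": "flashmla_prefill",
        "nsa_decode": "fa3",
    }
    if sm_major >= 10:  # Blackwell currently supports fp8 KV cache only
        cfg["kv_cache_dtype"] = "fp8_e4m3"
    if cfg["kv_cache_dtype"] == "fp8_e4m3":
        cfg["nsa_prefill"] = cfg["nsa_decode"] = "flashmla_decode"
    return cfg


print(nsa_defaults(sm_major=9))   # Hopper: KV cache dtype untouched, fa3 decode
print(nsa_defaults(sm_major=10))  # Blackwell: fp8 KV cache, flashmla_decode everywhere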
    def _handle_sampling_backend(self):
        if self.sampling_backend is None:
            self.sampling_backend = (
@@ -1023,6 +1072,7 @@ class ServerArgs:

        model_arch = self.get_hf_config().architectures[0]
        if model_arch in [
            "DeepseekV32ForCausalLM",
            "DeepseekV3ForCausalLM",
            "Glm4MoeForCausalLM",
            "BailingMoeForCausalLM",
@@ -1974,6 +2024,18 @@ class ServerArgs:
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
        parser.add_argument(
            "--nsa-prefill",
            default=ServerArgs.nsa_prefill,
            type=str,
            choices=NSA_CHOICES,
        )
        parser.add_argument(
            "--nsa-decode",
            default=ServerArgs.nsa_decode,
            type=str,
            choices=NSA_CHOICES,
        )
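With these arguments exposed, the per-phase NSA kernels can also be chosen explicitly at launch, e.g. --nsa-prefill flashmla_prefill --nsa-decode fa3; any value from NSA_CHOICES is accepted for either flag, and the fp8-KV-cache override above still takes precedence when it applies.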
        # Speculative decoding
        parser.add_argument(
@@ -3251,6 +3313,7 @@ def auto_choose_speculative_params(self: ServerArgs):
        # The default value for llama
        return (5, 4, 8)
    elif arch in [
        "DeepseekV32ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",

@@ -705,6 +705,8 @@ class TboForwardBatchPreparer:
            extend_num_tokens=extend_num_tokens,
            attn_backend=output_attn_backend,
            num_token_non_padded=out_num_token_non_padded,
            # TODO: handle it when we need TBO + DeepSeek V3.2
            num_token_non_padded_cpu=None,
            tbo_split_seq_index=None,
            tbo_parent_token_range=(start_token_index, end_token_index),
            tbo_children=None,

@@ -471,7 +471,7 @@ def is_pin_memory_available() -> bool:

class LayerFn(Protocol):

    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
@@ -482,7 +482,7 @@ def make_layers(
    prefix: str = "",
    return_tuple: bool = False,
    offloader_kwargs: Dict[str, Any] = {},
) -> Tuple[int, int, torch.nn.ModuleList]:
) -> Tuple[torch.nn.Module, int, int]:
    """Make a list of layers with the given layer function"""
    # circular imports
    from sglang.srt.distributed import get_pp_indices

@@ -123,6 +123,38 @@ def get_hf_text_config(config: PretrainedConfig):
        return config


# Temporary hack for the DeepSeek-V3.2 model
def _load_deepseek_v32_model(
    model_path: str,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
):
    # First resolve the local path,
    local_path = download_from_hf(model_path)
    # then load the JSON config file.
    config_file = os.path.join(local_path, "config.json")
    if not os.path.exists(config_file):
        raise RuntimeError(f"Can't find config file in {local_path}.")

    with open(config_file, "r") as f:
        config_json = json.load(f)

    config_json["architectures"] = ["DeepseekV3ForCausalLM"]
    config_json["model_type"] = "deepseek_v3"

    tmp_path = os.path.join(local_path, "_tmp_config_folder")
    os.makedirs(tmp_path, exist_ok=True)

    unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
    with open(unique_path, "w") as f:
        json.dump(config_json, f)

    return AutoConfig.from_pretrained(
        unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
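The rewrite works because transformers releases that predate the model raise a ValueError for the unregistered model_type "deepseek_v32"; pointing the config at the known deepseek_v3 layout lets AutoConfig parse it, while the V3.2-specific index fields survive as extra attributes so NSA detection still works. The essence of the patch step (field values here are illustrative):

config_json = {
    "model_type": "deepseek_v32",
    "architectures": ["DeepseekV32ForCausalLM"],
    "index_topk": 2048,  # illustrative value
}
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
# index_topk is untouched, so is_deepseek_nsa() can still detect the model
print(config_json)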
@lru_cache_frozenset(maxsize=32)
def get_config(
    model: str,
@@ -144,9 +176,17 @@ def get_config(
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        model = client.get_local_dir()

    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    try:
        config = AutoConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    except ValueError as e:
        if "deepseek_v32" not in str(e):
            raise e
        config = _load_deepseek_v32_model(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )

    if (
        config.architectures is not None
        and config.architectures[0] == "Phi4MMForCausalLM"

57
python/sglang/test/get_logits_ut.py
Normal file
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn


class DummyModel(nn.Module):
    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        q_scale = q_scale.unsqueeze(1)  # (B, 1, 1)
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights = self.weights_proj(x)
        q_scale = q_scale.unsqueeze(1)  # (B, 1, 1)
        scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale  # (B, 1, 1)
        weights = weights.unsqueeze(-1) * scale_const  # (B, 1024, 1)
        return weights


def main():
    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    import time

    start = time.time()
    for _ in range(1000):
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
    print("Original version time:", time.time() - start)

    start = time.time()
    for _ in range(1000):
        out_opt = model._get_logits_head_gate_opt(x, q_scale)
    print("Optimized version time:", time.time() - start)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"


if __name__ == "__main__":
    main()


"""
Original version time: 0.49235057830810547
Optimized version time: 0.4087331295013428
Difference: 1.4901161193847656e-08
"""