[Spec Decoding] Support MTP for dsv3.2 (#11652)
Co-authored-by: Paiiiiiiiiiiiiii <zengpai@baidu.com>
This commit is contained in:
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
|
||||
return (
|
||||
config.architectures is not None
|
||||
and config.architectures[0]
|
||||
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
|
||||
in [
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"DeepseekV3ForCausalLMNextN",
|
||||
]
|
||||
and getattr(config, "index_topk", None) is not None
|
||||
)
|
||||
|
||||
|
||||
@@ -266,7 +266,10 @@ class Indexer(CustomOp):
|
||||
)
|
||||
|
||||
blocksize = page_size
|
||||
seqlens_32 = metadata.get_seqlens_int32()
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
seqlens_32 = metadata.get_seqlens_expanded()
|
||||
else:
|
||||
seqlens_32 = metadata.get_seqlens_int32()
|
||||
# NOTE(dark): 132 is SM count on H200/B200, not magic number
|
||||
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
|
||||
seqlens_32, blocksize, self.sm_count
|
||||
@@ -317,8 +320,9 @@ class Indexer(CustomOp):
|
||||
k_fp8_list = []
|
||||
k_scale_list = []
|
||||
ks_list = []
|
||||
ke_list = []
|
||||
offset = 0
|
||||
|
||||
seq_lens_expanded = metadata.get_seqlens_expanded()
|
||||
block_tables = metadata.get_page_table_64()
|
||||
|
||||
assert (
|
||||
@@ -341,30 +345,34 @@ class Indexer(CustomOp):
|
||||
)
|
||||
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
|
||||
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
|
||||
ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
|
||||
k_fp8_list.append(k_fp8)
|
||||
k_scale_list.append(k_scale)
|
||||
ks_list.append(ks)
|
||||
ke_list.append(ke)
|
||||
offset += extend_seq_len
|
||||
|
||||
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
|
||||
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
|
||||
kv_fp8 = (k_fp8, k_scale)
|
||||
ks = torch.cat(ks_list, dim=0)
|
||||
seq_lens_expanded = metadata.get_seqlens_expanded()
|
||||
ke = ks + seq_lens_expanded
|
||||
ke = torch.cat(ke_list, dim=0)
|
||||
|
||||
logits = deep_gemm.fp8_mqa_logits(
|
||||
q_fp8,
|
||||
q_fp8[:offset],
|
||||
kv_fp8,
|
||||
weights,
|
||||
weights[:offset],
|
||||
ks,
|
||||
ke,
|
||||
clean_logits=False,
|
||||
)
|
||||
|
||||
token_nums, _, _ = q_fp8.shape
|
||||
assert logits.shape[0] == len(seq_lens_expanded)
|
||||
topk_result = metadata.topk_transform(logits, self.index_topk)
|
||||
|
||||
raw_topk_result = metadata.topk_transform(logits, self.index_topk)
|
||||
topk_result = torch.full(
|
||||
(token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
|
||||
)
|
||||
topk_result[:offset] = raw_topk_result
|
||||
return topk_result
|
||||
|
||||
def forward_indexer(
|
||||
@@ -500,6 +508,8 @@ class Indexer(CustomOp):
|
||||
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
|
||||
# k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
|
||||
# k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
|
||||
if not forward_batch.out_cache_loc.is_contiguous():
|
||||
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
|
||||
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
|
||||
layer_id=layer_id,
|
||||
loc=forward_batch.out_cache_loc,
|
||||
@@ -521,7 +531,10 @@ class Indexer(CustomOp):
|
||||
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
|
||||
)
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
or forward_batch.forward_mode.is_target_verify()
|
||||
):
|
||||
topk_result = self._get_topk_paged(
|
||||
forward_batch, layer_id, q_fp8, weights, metadata
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_hip:
|
||||
@@ -148,7 +149,14 @@ NSA_DECODE_IMPL: _NSA_IMPL_T
|
||||
|
||||
|
||||
class NativeSparseAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
speculative_step_id=0,
|
||||
topk=0,
|
||||
speculative_num_steps=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.forward_metadata: NSAMetadata
|
||||
self.device = model_runner.device
|
||||
@@ -185,6 +193,14 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
# Speculative decoding
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk or 0
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.speculative_num_draft_tokens = (
|
||||
model_runner.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
self.speculative_step_id = speculative_step_id
|
||||
|
||||
def get_device_int32_arange(self, l: int) -> torch.Tensor:
|
||||
if l > len(self._arange_buf):
|
||||
next_pow_of_2 = 1 << (l - 1).bit_length()
|
||||
@@ -208,13 +224,15 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
batch_size = forward_batch.batch_size
|
||||
device = forward_batch.seq_lens.device
|
||||
|
||||
assert (
|
||||
forward_batch.spec_info is None
|
||||
), "Spec decoding is not supported for NSA backend now"
|
||||
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
draft_token_num = self.speculative_num_draft_tokens
|
||||
else:
|
||||
draft_token_num = 0
|
||||
|
||||
cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
|
||||
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||
assert forward_batch.seq_lens_cpu is not None
|
||||
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
|
||||
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
|
||||
page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, :max_seqlen_k
|
||||
]
|
||||
@@ -224,6 +242,41 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
max_seqlen_q = 1
|
||||
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
|
||||
seqlens_expanded = cache_seqlens_int32
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
max_seqlen_q = self.speculative_num_draft_tokens
|
||||
nsa_max_seqlen_q = self.speculative_num_draft_tokens
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size * self.speculative_num_draft_tokens + 1,
|
||||
1,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
|
||||
forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
|
||||
|
||||
seqlens_int32_cpu = [
|
||||
self.speculative_num_draft_tokens + kv_len
|
||||
for kv_len in forward_batch.seq_lens_cpu.tolist()
|
||||
]
|
||||
seqlens_expanded = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
kv_len - qo_len + 1,
|
||||
kv_len + 1,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
for qo_len, kv_len in zip(
|
||||
extend_seq_lens_cpu,
|
||||
seqlens_int32_cpu,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
page_table = torch.repeat_interleave(
|
||||
page_table, repeats=self.speculative_num_draft_tokens, dim=0
|
||||
)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
assert (
|
||||
forward_batch.extend_seq_lens_cpu is not None
|
||||
@@ -232,7 +285,11 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
), "All of them must not be None"
|
||||
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
|
||||
assert forward_batch.extend_seq_lens is not None
|
||||
if any(forward_batch.extend_prefix_lens_cpu):
|
||||
|
||||
if (
|
||||
any(forward_batch.extend_prefix_lens_cpu)
|
||||
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
|
||||
):
|
||||
max_seqlen_q = max(extend_seq_lens_cpu)
|
||||
cu_seqlens_q = compute_cu_seqlens(
|
||||
forward_batch.extend_seq_lens.to(torch.int32)
|
||||
@@ -277,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
flashmla_metadata=(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1, # TODO handle MTP which is not 1
|
||||
seq_len_q=1,
|
||||
)
|
||||
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||
else None
|
||||
@@ -288,6 +345,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
nsa_seqlens_expanded=seqlens_expanded,
|
||||
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
|
||||
real_page_table=self._transform_table_1_to_real(page_table),
|
||||
nsa_max_seqlen_q=1,
|
||||
)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
@@ -302,7 +360,9 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
to avoid memory allocations.
|
||||
"""
|
||||
self.decode_cuda_graph_metadata: Dict = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cache_seqlens": torch.ones(
|
||||
max_num_tokens, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
@@ -311,7 +371,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
),
|
||||
# fake page_table for sparse_prefill
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
max_num_tokens,
|
||||
self.max_context_len,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
@@ -319,9 +379,9 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
"flashmla_metadata": (
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=torch.ones(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
max_num_tokens, dtype=torch.int32, device=self.device
|
||||
),
|
||||
seq_len_q=1, # TODO handle MTP which is not 1
|
||||
seq_len_q=1,
|
||||
)
|
||||
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||
else None
|
||||
@@ -339,50 +399,166 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
"""Initialize forward metadata for capturing CUDA graph."""
|
||||
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
|
||||
assert (
|
||||
spec_info is None
|
||||
), "Speculative decoding is not supported for NSA backend now"
|
||||
if forward_mode.is_decode_or_idle():
|
||||
# Normal Decode
|
||||
# Get sequence information
|
||||
cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||
|
||||
# Normal Decode
|
||||
# Get sequence information
|
||||
cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||
# Use max context length for seq_len_k
|
||||
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||
max_seqlen_q = 1
|
||||
max_seqlen_k = page_table_1.shape[1]
|
||||
|
||||
# Use max context length for seq_len_k
|
||||
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||
max_seq_len_k = page_table_1.shape[1]
|
||||
# Precompute page table
|
||||
# Precompute cumulative sequence lengths
|
||||
|
||||
# Precompute page table
|
||||
# Precompute cumulative sequence lengths
|
||||
# NOTE(dark): this is always arange, since we are decoding
|
||||
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
|
||||
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
|
||||
seqlens_expanded = cache_seqlens_int32
|
||||
nsa_extend_seq_lens_list = [1] * num_tokens
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, num_tokens + 1))
|
||||
flashmla_metadata.copy_(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
flashmla_metadata = None
|
||||
elif forward_mode.is_target_verify():
|
||||
cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
|
||||
torch.int32
|
||||
)
|
||||
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||
max_seqlen_q = 1
|
||||
page_table_1 = self.decode_cuda_graph_metadata["page_table"][
|
||||
: bs * self.speculative_num_draft_tokens, :
|
||||
]
|
||||
max_seqlen_k = page_table_1.shape[1]
|
||||
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * self.speculative_num_draft_tokens + 1,
|
||||
1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
|
||||
|
||||
seqlens_int32_cpu = [
|
||||
self.speculative_num_draft_tokens + kv_len
|
||||
for kv_len in seq_lens.tolist()
|
||||
]
|
||||
seqlens_expanded = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
kv_len - qo_len + 1,
|
||||
kv_len + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
for qo_len, kv_len in zip(
|
||||
extend_seq_lens_cpu,
|
||||
seqlens_int32_cpu,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||
seqlens_expanded, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
|
||||
|
||||
flashmla_metadata.copy_(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
flashmla_metadata = None
|
||||
elif forward_mode.is_draft_extend():
|
||||
cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
|
||||
torch.int32
|
||||
)
|
||||
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||
max_seqlen_k = page_table_1.shape[1]
|
||||
|
||||
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
|
||||
extend_seq_lens = torch.full(
|
||||
(bs,),
|
||||
self.speculative_num_draft_tokens,
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
max_seqlen_q = max(extend_seq_lens_cpu)
|
||||
cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
|
||||
|
||||
seqlens_int32_cpu = [
|
||||
self.speculative_num_draft_tokens + kv_len
|
||||
for kv_len in seq_lens.tolist()
|
||||
]
|
||||
seqlens_expanded = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
kv_len - qo_len + 1,
|
||||
kv_len + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
for qo_len, kv_len in zip(
|
||||
extend_seq_lens_cpu,
|
||||
seqlens_int32_cpu,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||
seqlens_expanded, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
nsa_extend_seq_lens_list = [1] * bs
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
|
||||
# As the DeepGemm is not support for q_len = 3/4 in Indexer and every token has independent topk_indices,
|
||||
# we made the Q shape [bs * speculative_num_draft_tokens, 1, head_nums, dim].
|
||||
# So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend mode.
|
||||
flashmla_metadata.copy_(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
flashmla_metadata = None
|
||||
|
||||
# NOTE(dark): this is always arange, since we are decoding
|
||||
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
|
||||
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
|
||||
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
|
||||
real_page_table = self._transform_table_1_to_real(page_table_1)
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, bs + 1))
|
||||
flashmla_metadata.copy_(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1, # TODO handle MTP which is not 1
|
||||
)
|
||||
)
|
||||
else:
|
||||
flashmla_metadata = None
|
||||
|
||||
metadata = NSAMetadata(
|
||||
page_size=self.real_page_size,
|
||||
cache_seqlens_int32=cache_seqlens_int32,
|
||||
max_seq_len_q=1,
|
||||
max_seq_len_k=max_seq_len_k,
|
||||
max_seq_len_q=max_seqlen_q,
|
||||
max_seq_len_k=max_seqlen_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
page_table_1=page_table_1,
|
||||
@@ -390,9 +566,9 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
|
||||
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
|
||||
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
|
||||
nsa_seqlens_expanded=cache_seqlens_int32,
|
||||
nsa_seqlens_expanded=seqlens_expanded,
|
||||
real_page_table=real_page_table,
|
||||
nsa_extend_seq_lens_list=[1] * bs,
|
||||
nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
@@ -411,33 +587,119 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
):
|
||||
"""Initialize forward metadata for replaying CUDA graph."""
|
||||
assert seq_lens_cpu is not None
|
||||
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
|
||||
assert (
|
||||
spec_info is None
|
||||
), "Speculative decoding is not supported for NSA backend now"
|
||||
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
|
||||
# Normal Decode
|
||||
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
|
||||
max_len = int(seq_lens_cpu.max().item())
|
||||
if forward_mode.is_decode_or_idle():
|
||||
# Normal Decode
|
||||
max_len = int(seq_lens_cpu.max().item())
|
||||
|
||||
cache_seqlens = seq_lens.to(torch.int32)
|
||||
metadata.cache_seqlens_int32.copy_(cache_seqlens)
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
|
||||
)
|
||||
page_indices = self.req_to_token[req_pool_indices, :max_len]
|
||||
metadata.page_table_1[:, :max_len].copy_(page_indices)
|
||||
cache_seqlens = seq_lens.to(torch.int32)
|
||||
metadata.cache_seqlens_int32.copy_(cache_seqlens)
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
|
||||
)
|
||||
page_indices = self.req_to_token[req_pool_indices, :max_len]
|
||||
metadata.page_table_1[:, :max_len].copy_(page_indices)
|
||||
nsa_cache_seqlens = compute_nsa_seqlens(
|
||||
cache_seqlens, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
|
||||
seqlens_expanded = cache_seqlens
|
||||
elif forward_mode.is_target_verify():
|
||||
max_seqlen_k = int(
|
||||
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
|
||||
torch.int32
|
||||
)
|
||||
metadata.cache_seqlens_int32.copy_(cache_seqlens)
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
|
||||
)
|
||||
page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
|
||||
page_indices = torch.repeat_interleave(
|
||||
page_indices, repeats=self.speculative_num_draft_tokens, dim=0
|
||||
)
|
||||
metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
|
||||
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
|
||||
|
||||
seqlens_int32_cpu = [
|
||||
self.speculative_num_draft_tokens + kv_len
|
||||
for kv_len in seq_lens_cpu.tolist()
|
||||
]
|
||||
seqlens_expanded = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
kv_len - qo_len + 1,
|
||||
kv_len + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
for qo_len, kv_len in zip(
|
||||
extend_seq_lens_cpu,
|
||||
seqlens_int32_cpu,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
|
||||
nsa_cache_seqlens = compute_nsa_seqlens(
|
||||
seqlens_expanded, self.nsa_index_topk
|
||||
)
|
||||
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
|
||||
elif forward_mode.is_draft_extend():
|
||||
max_seqlen_k = int(seq_lens_cpu.max().item())
|
||||
cache_seqlens = seq_lens.to(torch.int32)
|
||||
metadata.cache_seqlens_int32.copy_(cache_seqlens)
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
|
||||
)
|
||||
page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
|
||||
metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
|
||||
extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
|
||||
|
||||
seqlens_int32_cpu = [
|
||||
self.speculative_num_draft_tokens + kv_len
|
||||
for kv_len in seq_lens_cpu.tolist()
|
||||
]
|
||||
seqlens_expanded = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
kv_len - qo_len + 1,
|
||||
kv_len + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
for qo_len, kv_len in zip(
|
||||
extend_seq_lens_cpu,
|
||||
seqlens_int32_cpu,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
|
||||
seqlens_expanded
|
||||
)
|
||||
nsa_cache_seqlens = compute_nsa_seqlens(
|
||||
seqlens_expanded, self.nsa_index_topk
|
||||
)
|
||||
metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
|
||||
nsa_cache_seqlens
|
||||
)
|
||||
seqlens_expanded_size = seqlens_expanded.size(0)
|
||||
assert (
|
||||
metadata.nsa_cache_seqlens_int32 is not None
|
||||
and metadata.nsa_cu_seqlens_k is not None
|
||||
and self.nsa_index_topk is not None
|
||||
)
|
||||
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
|
||||
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
|
||||
metadata.nsa_cu_seqlens_k[1:].copy_(
|
||||
|
||||
metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
|
||||
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
|
||||
)
|
||||
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
|
||||
@@ -451,10 +713,13 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
assert metadata.real_page_table is metadata.page_table_1
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
metadata.flashmla_metadata.copy_(
|
||||
flashmla_metadata = metadata.flashmla_metadata.slice(
|
||||
slice(0, seqlens_expanded_size + 1)
|
||||
)
|
||||
flashmla_metadata.copy_(
|
||||
self._compute_flashmla_metadata(
|
||||
cache_seqlens=nsa_cache_seqlens,
|
||||
seq_len_q=1, # TODO handle MTP which is not 1
|
||||
seq_len_q=1,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -473,10 +738,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
topk_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
), "NSA backend doesn't support speculative decoding"
|
||||
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
@@ -884,3 +1146,58 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
flashmla_metadata=flashmla_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class NativeSparseAttnMultiStepBackend:
|
||||
|
||||
def __init__(
|
||||
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
|
||||
):
|
||||
self.model_runner = model_runner
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.attn_backends = []
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends.append(
|
||||
NativeSparseAttnBackend(
|
||||
model_runner,
|
||||
speculative_step_id=i,
|
||||
topk=self.topk,
|
||||
speculative_num_steps=self.speculative_num_steps,
|
||||
)
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
forward_batch.batch_size * self.topk,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
seq_lens_sum=-1,
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
@@ -48,6 +48,7 @@ class DraftBackendFactory:
|
||||
"flashmla": self._create_flashmla_decode_backend,
|
||||
"trtllm_mha": self._create_trtllm_mha_decode_backend,
|
||||
"trtllm_mla": self._create_trtllm_mla_decode_backend,
|
||||
"nsa": self._create_nsa_decode_backend,
|
||||
}
|
||||
|
||||
return self._create_backend(
|
||||
@@ -70,6 +71,7 @@ class DraftBackendFactory:
|
||||
"flashmla": self._create_flashmla_prefill_backend,
|
||||
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
||||
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
||||
"nsa": self._create_nsa_prefill_backend,
|
||||
}
|
||||
backend_name = (
|
||||
"decode_attention_backend"
|
||||
@@ -82,6 +84,20 @@ class DraftBackendFactory:
|
||||
"EAGLE is not supported in attention backend {backend_type}",
|
||||
)
|
||||
|
||||
def _create_nsa_decode_backend(self):
|
||||
from sglang.srt.layers.attention.nsa_backend import (
|
||||
NativeSparseAttnMultiStepBackend,
|
||||
)
|
||||
|
||||
return NativeSparseAttnMultiStepBackend(
|
||||
self.draft_model_runner, self.topk, self.speculative_num_steps
|
||||
)
|
||||
|
||||
def _create_nsa_prefill_backend(self):
|
||||
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
|
||||
|
||||
return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
|
||||
|
||||
def _create_flashinfer_decode_backend(self):
|
||||
if not get_global_server_args().use_mla_backend:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
|
||||
@@ -81,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.seq_lens_cpu = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
|
||||
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
@@ -92,6 +93,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.seq_lens = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
||||
self.out_cache_loc = torch.zeros(
|
||||
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
|
||||
)
|
||||
@@ -165,6 +167,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
# Graph inputs
|
||||
req_pool_indices = self.req_pool_indices[:num_seqs]
|
||||
seq_lens = self.seq_lens[:num_seqs]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
|
||||
extend_seq_lens = self.extend_seq_lens[:num_seqs]
|
||||
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
|
||||
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
|
||||
positions = self.positions[:num_tokens]
|
||||
mrope_positions = self.mrope_positions[:, :num_tokens]
|
||||
@@ -227,6 +232,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
input_ids=None,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_seq_lens_cpu=extend_seq_lens_cpu,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
|
||||
@@ -78,6 +78,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.seq_lens_cpu = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
|
||||
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
@@ -196,7 +197,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
req_pool_indices = self.req_pool_indices[:bs]
|
||||
seq_lens = self.seq_lens[:bs]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
extend_seq_lens = self.extend_seq_lens[:bs]
|
||||
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
|
||||
accept_length = self.accept_length[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
@@ -254,6 +257,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
input_ids=input_ids,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
next_token_logits_buffer=next_token_logits_buffer,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
@@ -271,6 +275,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
attn_backend=self.eagle_worker.draft_extend_attn_backend,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_seq_lens_cpu=extend_seq_lens_cpu,
|
||||
padded_static_len=self.padded_static_len,
|
||||
)
|
||||
|
||||
@@ -373,6 +378,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if forward_batch.extend_seq_lens_cpu is not None:
|
||||
self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
|
||||
|
||||
if bs != raw_bs:
|
||||
forward_batch.spec_info.positions = self.positions[:num_tokens]
|
||||
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|
||||
|
||||
Reference in New Issue
Block a user