add 2 kernels and optimize the calculation of topk_indices (#134)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-22 10:29:28 +08:00
committed by GitHub
parent c9f00c132c
commit 74d4f804e8
4 changed files with 108 additions and 32 deletions

View File

@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata)
kv_spans_from_batches)
logger = init_logger(__name__)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
context_q_lens: torch.Tensor
context_q_lens_cpu: torch.Tensor
context_k_lens: torch.Tensor
context_k_lens_cpu: torch.Tensor
@dataclass
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
query_start_loc_cpu, seq_lens_cpu,
block_table):
prefill_query_start_loc = query_start_loc_cpu[
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
self.device)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = torch.cat([
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
]).to(torch.int32).to(self.device)
seq_len_q = token_end - token_start
seq_len_kv = total_seq_lens
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
context_q_lens=context_q_lens,
context_q_lens_cpu=context_q_lens_cpu,
context_k_lens=context_k_lens,
context_k_lens_cpu=context_k_lens_cpu,
)
def kunlun_build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
@@ -130,4 +187,5 @@ def kunlun_build(self,
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
DeepseekV32IndexerMetadataBuilder.build = kunlun_build