add 2 kernels and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata)
|
||||
kv_spans_from_batches)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
    """Indexer metadata describing one chunk of prefill requests.

    Instances are produced by ``kunlun_build_one_prefill_chunk``. The field
    order below is part of the generated ``__init__`` signature, so it must
    not be rearranged.
    """

    # Block-table rows of the requests that belong to this chunk.
    block_table: torch.Tensor

    # Start / end KV spans as returned by ``kv_spans_from_batches``.
    # NOTE(review): presumably one entry per query token -- confirm against
    # the helper's implementation.
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor

    # int32 cumulative sequence lengths of the chunk's requests, with a
    # leading zero.
    cu_seq_lens: torch.Tensor

    # Sum of the sequence lengths of every request in the chunk.
    total_seq_lens: int

    # Query-token offsets bounding the chunk, taken from the query start
    # locations of its first request and of the request just past its end.
    token_start: int
    token_end: int

    # Number of requests covered by the chunk.
    num_reqs: int

    # Two-element int32 tensors ``[0, num_query_tokens]`` and
    # ``[0, total_seq_lens]``, each kept both on the builder's device and
    # on the CPU.
    context_q_lens: torch.Tensor
    context_q_lens_cpu: torch.Tensor
    context_k_lens: torch.Tensor
    context_k_lens_cpu: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class DeepseekV32IndexerPrefillMetadata:
    """Container for the prefill-side indexer metadata of one batch."""

    # One entry per prefill chunk.
    # NOTE(review): presumably ordered by request index -- confirm against
    # the code that fills this list.
    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
|
||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
                                   query_start_loc_cpu, seq_lens_cpu,
                                   block_table):
    """Build the indexer metadata for prefill requests [reqs_start, reqs_end).

    Written as a free function and installed on
    ``DeepseekV32IndexerMetadataBuilder`` as ``build_one_prefill_chunk``,
    so ``self`` is the builder instance; ``self.device`` and
    ``self.max_prefill_buffer_size`` are read from it.

    Args:
        reqs_start: Index of the first request in the chunk (inclusive).
        reqs_end: Index one past the last request in the chunk.
        query_start_loc_cpu: Tensor of cumulative query-token offsets,
            indexed by request. It is indexed at ``reqs_end``, so it must
            hold at least ``reqs_end + 1`` entries.
        seq_lens_cpu: Tensor of per-request sequence lengths.
        block_table: Block table indexed by request along its first axis.

    Returns:
        A ``DeepseekV32IndexerPrefillChunkMetadata`` for the chunk.

    Raises:
        AssertionError: If the chunk's total sequence length exceeds
            ``self.max_prefill_buffer_size``.
    """
    chunk_seq_lens_cpu = seq_lens_cpu[reqs_start:reqs_end]

    # Re-base the query offsets so that the chunk's first token is at 0.
    prefill_query_start_loc = (
        query_start_loc_cpu[reqs_start:reqs_end + 1] -
        query_start_loc_cpu[reqs_start])
    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
        prefill_query_start_loc, chunk_seq_lens_cpu, self.device)

    token_start = query_start_loc_cpu[reqs_start].item()
    token_end = query_start_loc_cpu[reqs_end].item()

    # Convert to a plain Python int: the metadata field is declared ``int``
    # and the sibling token offsets above are already unwrapped the same
    # way. Leaving a 0-dim tensor here would leak a tensor into an int slot.
    total_seq_lens = int(chunk_seq_lens_cpu.sum().item())
    assert total_seq_lens <= self.max_prefill_buffer_size, (
        f"prefill chunk needs {total_seq_lens} KV slots, but the buffer "
        f"holds only {self.max_prefill_buffer_size}")

    # Cumulative sequence lengths with a leading zero, as int32 on device.
    cu_seq_lens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        chunk_seq_lens_cpu.cumsum(dim=0),
    ]).to(torch.int32).to(self.device)

    # The whole chunk is described as a single [0, length] range for the
    # query side and for the key side, mirrored on device and on CPU.
    q_bounds = [0, token_end - token_start]
    k_bounds = [0, total_seq_lens]
    context_q_lens = torch.tensor(q_bounds,
                                  dtype=torch.int32,
                                  device=self.device)
    context_k_lens = torch.tensor(k_bounds,
                                  dtype=torch.int32,
                                  device=self.device)
    context_q_lens_cpu = torch.tensor(q_bounds,
                                      dtype=torch.int32,
                                      device="cpu")
    context_k_lens_cpu = torch.tensor(k_bounds,
                                      dtype=torch.int32,
                                      device="cpu")

    return DeepseekV32IndexerPrefillChunkMetadata(
        cu_seqlen_ks=cu_seqlen_ks,
        cu_seqlen_ke=cu_seqlen_ke,
        cu_seq_lens=cu_seq_lens,
        total_seq_lens=total_seq_lens,
        block_table=block_table[reqs_start:reqs_end],
        token_start=token_start,
        token_end=token_end,
        num_reqs=reqs_end - reqs_start,
        context_q_lens=context_q_lens,
        context_q_lens_cpu=context_q_lens_cpu,
        context_k_lens=context_k_lens,
        context_k_lens_cpu=context_k_lens_cpu,
    )
|
||||
def kunlun_build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
@@ -130,4 +187,5 @@ def kunlun_build(self,
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
|
||||
# Install the Kunlun-specific implementations on the upstream builder class
# by rebinding its methods at import time.
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
    kunlun_build_one_prefill_chunk)
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
Reference in New Issue
Block a user