add 2 kernels and optimize the calculation of topk_indices (#134)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-22 10:29:28 +08:00
committed by GitHub
parent c9f00c132c
commit 74d4f804e8
4 changed files with 108 additions and 32 deletions

View File

@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
if has_prefill:
prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks:
k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
kv_cache,
chunk.block_table,
chunk.cu_seq_lens,
chunk.num_reqs,
head_dim,
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device,
dtype=torch.int8)
k_scale = torch.empty([chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8)
torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
kv_cache=kv_cache,
dst_k=k_fp8,
dst_scale=k_scale,
block_table=chunk.block_table,
cu_seq_lens=chunk.cu_seq_lens,
)
logits = int8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
context_q_lens_xpu=chunk.context_q_lens,
context_q_lens_cpu=chunk.context_q_lens_cpu,
context_k_lens_xpu=chunk.context_k_lens,
context_k_lens_cpu=chunk.context_k_lens_cpu,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
del k_fp8, k_scale
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
topk_indices -= chunk.cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = logits.topk(topk_tokens,
dim=-1)[1].to(torch.int32) # [B * N, K]
topk_indices = torch.argsort(logits, dim=-1,
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K