Add 2 kernels and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
chunk.num_reqs,
|
||||
head_dim,
|
||||
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||
device=k.device,
|
||||
dtype=torch.int8)
|
||||
k_scale = torch.empty([chunk.total_seq_lens, 4],
|
||||
device=k.device,
|
||||
dtype=torch.uint8)
|
||||
torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache=kv_cache,
|
||||
dst_k=k_fp8,
|
||||
dst_scale=k_scale,
|
||||
block_table=chunk.block_table,
|
||||
cu_seq_lens=chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
logits = int8_mqa_logits(
|
||||
q_fp8[chunk.token_start:chunk.token_end],
|
||||
(k_fp8, k_scale),
|
||||
(k_fp8, k_scale.view(torch.float32)),
|
||||
weights[chunk.token_start:chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
context_q_lens_xpu=chunk.context_q_lens,
|
||||
context_q_lens_cpu=chunk.context_q_lens_cpu,
|
||||
context_k_lens_xpu=chunk.context_k_lens,
|
||||
context_k_lens_cpu=chunk.context_k_lens_cpu,
|
||||
)
|
||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||
dim=-1)[1]
|
||||
del k_fp8, k_scale
|
||||
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
|
||||
|
||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||
mask_lo = topk_indices >= 0
|
||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
topk_indices = logits.topk(topk_tokens,
|
||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
topk_indices = torch.argsort(logits, dim=-1,
|
||||
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
|
||||
|
||||
# ensure we don't set indices for the top k
|
||||
# that are out of range (already masked)
|
||||
# this will happen if context length is shorter than K
|
||||
|
||||
Reference in New Issue
Block a user