From 6f128308393414bde671d7abdb2bf74942734c11 Mon Sep 17 00:00:00 2001 From: fromck <74886593+fromck@users.noreply.github.com> Date: Mon, 2 Feb 2026 11:07:49 +0800 Subject: [PATCH] [Kernel] add topk_per_row to optimize the calculation of topk_indexes (#168) Signed-off-by: chengxiaokang Co-authored-by: chengxiaokang --- vllm_kunlun/models/deepseek_v2.py | 87 +++++++++++++------------------ 1 file changed, 37 insertions(+), 50 deletions(-) diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py index 52478ba..d65ab3e 100644 --- a/vllm_kunlun/models/deepseek_v2.py +++ b/vllm_kunlun/models/deepseek_v2.py @@ -523,21 +523,23 @@ def sparse_attn_indexer_vllm_kunlun( context_k_lens_cpu=chunk.context_k_lens_cpu, ) del k_fp8, k_scale - topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])] - - topk_indices -= chunk.cu_seqlen_ks[:, None] - mask_lo = topk_indices >= 0 - mask_hi = topk_indices - (chunk.cu_seqlen_ke - - chunk.cu_seqlen_ks)[:, None] < 0 - mask = torch.full_like(topk_indices, - False, - dtype=torch.bool, - device=topk_indices.device) - mask = mask_lo & mask_hi - topk_indices = topk_indices.masked_fill(~mask, -1) - topk_indices_buffer[ - chunk.token_start:chunk.token_end, :topk_indices. - shape[-1]] = topk_indices.to(dtype=torch.int32) + + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[ + chunk.token_start:chunk.token_end, :topk_tokens] + + # when seqLens=None and next_n=None, it means that it is used to calculate topk_indices in prefill + # refer to top_k_per_row_prefill:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L698 + torch.ops.xspeedgate_ops.topk_per_row(logits=logits, + srcIndices=topk_indices, + numRows=num_rows, + stride0=logits.stride(0), + stride1=logits.stride(1), + topK=topk_tokens, + rowStarts=chunk.cu_seqlen_ks, + rowEnds=chunk.cu_seqlen_ke, + seqLens=None, + next_n=None) if has_decode: decode_metadata = attn_metadata.decode @@ -570,48 +572,33 @@ def sparse_attn_indexer_vllm_kunlun( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - # padded query len - current_device = padded_q_fp8_decode_tokens.device - padded_num_tokens = batch_size * next_n - positions = torch.arange(max_model_len, - device=current_device).unsqueeze(0).expand( - batch_size * next_n, -1) - row_indices = torch.arange(padded_num_tokens, - device=current_device) // next_n - next_n_offset = torch.arange( - padded_num_tokens, - device=padded_q_fp8_decode_tokens.device) % next_n - index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + - next_n_offset).unsqueeze(1) - # index_end_pos: [B * N, 1] - mask = positions <= index_end_pos - # mask: [B * N, L] - logits = logits.masked_fill(~mask, float('-inf')) - del positions, mask - topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K] - need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens - if need_mask: - positions_topk = torch.arange(topk_tokens, - device=current_device).unsqueeze(0).expand( - batch_size * next_n, -1) - mask_topk = positions_topk <= index_end_pos - topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K] - - # ensure we don't set indices for the top k - # that is out of range(masked already) - # this will happen if context length is shorter than K - topk_indices[topk_indices > index_end_pos] = -1 + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + + # when row_starts=None and row_ends=None, it means that it is used to calculate topk_indices in decode + # refer to top_k_per_row_decode:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L643 + torch.ops.xspeedgate_ops.topk_per_row(logits=logits, + srcIndices=topk_indices, + numRows=num_rows, + stride0=logits.stride(0), + stride1=logits.stride(1), + topK=topk_tokens, + rowStarts=None, + rowEnds=None, + seqLens=decode_metadata.seq_lens, + next_n=next_n) + + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens) - topk_indices_buffer[:num_decode_tokens, :topk_indices. - shape[-1]] = topk_indices.to(dtype=torch.int32) - - # return topk_indices_buffer + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) def sparse_attn_indexer_vllm_kunlun_fake(