[Kernel] add topk_per_row to optimize the calculation of topk_indexes (#168)

Signed-off-by: chengxiaokang <chengxiaokang@baidu.com>
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-02-02 11:07:49 +08:00
committed by GitHub
parent 726cefb7a3
commit 6f12830839

View File

@@ -523,21 +523,23 @@ def sparse_attn_indexer_vllm_kunlun(
context_k_lens_cpu=chunk.context_k_lens_cpu,
)
del k_fp8, k_scale
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
topk_indices -= chunk.cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
chunk.cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(topk_indices,
False,
dtype=torch.bool,
device=topk_indices.device)
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)
topk_indices_buffer[
chunk.token_start:chunk.token_end, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
chunk.token_start:chunk.token_end, :topk_tokens]
# when seqLens=None and next_n=None, it means that it is used to calculate topk_indices in prefill
# refer to top_k_per_row_prefillhttps://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L698
torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
srcIndices=topk_indices,
numRows=num_rows,
stride0=logits.stride(0),
stride1=logits.stride(1),
topK=topk_tokens,
rowStarts=chunk.cu_seqlen_ks,
rowEnds=chunk.cu_seqlen_ke,
seqLens=None,
next_n=None)
if has_decode:
decode_metadata = attn_metadata.decode
@@ -570,48 +572,33 @@ def sparse_attn_indexer_vllm_kunlun(
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
positions = torch.arange(max_model_len,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
row_indices = torch.arange(padded_num_tokens,
device=current_device) // next_n
next_n_offset = torch.arange(
padded_num_tokens,
device=padded_q_fp8_decode_tokens.device) % next_n
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
del positions, mask
topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K]
need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens
if need_mask:
positions_topk = torch.arange(topk_tokens,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
mask_topk = positions_topk <= index_end_pos
topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
# when row_starts=None and row_ends=None, it means that it is used to calculate topk_indices in decode
# refer to top_k_per_row_decodehttps://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L643
torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
srcIndices=topk_indices,
numRows=num_rows,
stride0=logits.stride(0),
stride1=logits.stride(1),
topK=topk_tokens,
rowStarts=None,
rowEnds=None,
seqLens=decode_metadata.seq_lens,
next_n=next_n)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens)
topk_indices_buffer[:num_decode_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
# return topk_indices_buffer
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices
)
def sparse_attn_indexer_vllm_kunlun_fake(