[Kernel] add topk_per_row to optimize the calculation of topk_indices (#168)
Signed-off-by: chengxiaokang <chengxiaokang@baidu.com>
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
@@ -523,21 +523,23 @@ def sparse_attn_indexer_vllm_kunlun(
                 context_k_lens_cpu=chunk.context_k_lens_cpu,
             )
             del k_fp8, k_scale
-            topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
-
-            topk_indices -= chunk.cu_seqlen_ks[:, None]
-            mask_lo = topk_indices >= 0
-            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
-                                      chunk.cu_seqlen_ks)[:, None] < 0
-            mask = torch.full_like(topk_indices,
-                                   False,
-                                   dtype=torch.bool,
-                                   device=topk_indices.device)
-            mask = mask_lo & mask_hi
-            topk_indices = topk_indices.masked_fill(~mask, -1)
-            topk_indices_buffer[
-                chunk.token_start:chunk.token_end, :topk_indices.
-                shape[-1]] = topk_indices.to(dtype=torch.int32)
+            num_rows = logits.shape[0]
+            topk_indices = topk_indices_buffer[
+                chunk.token_start:chunk.token_end, :topk_tokens]
+
+            # seqLens=None and next_n=None means this call computes
+            # topk_indices for prefill.
+            # Refer to top_k_per_row_prefill: https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L698
+            torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
+                                                  srcIndices=topk_indices,
+                                                  numRows=num_rows,
+                                                  stride0=logits.stride(0),
+                                                  stride1=logits.stride(1),
+                                                  topK=topk_tokens,
+                                                  rowStarts=chunk.cu_seqlen_ks,
+                                                  rowEnds=chunk.cu_seqlen_ke,
+                                                  seqLens=None,
+                                                  next_n=None)

     if has_decode:
         decode_metadata = attn_metadata.decode
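
Reviewer note: for reference, a minimal eager-PyTorch sketch of what the prefill mode of topk_per_row (rowStarts/rowEnds set, seqLens/next_n unset) is expected to compute, mirroring the argsort path removed above. The name topk_per_row_prefill_ref is hypothetical, for illustration only:

import torch

def topk_per_row_prefill_ref(logits: torch.Tensor,
                             cu_seqlen_ks: torch.Tensor,
                             cu_seqlen_ke: torch.Tensor,
                             topk_tokens: int) -> torch.Tensor:
    # Per row, keep the top-k logits whose key index falls inside the
    # row's window [cu_seqlen_ks, cu_seqlen_ke); indices come out
    # relative to the window start, and out-of-window slots become -1.
    k = min(topk_tokens, logits.shape[-1])
    topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :k]
    topk_indices -= cu_seqlen_ks[:, None]  # make indices window-relative
    valid = ((topk_indices >= 0) &
             (topk_indices < (cu_seqlen_ke - cu_seqlen_ks)[:, None]))
    return topk_indices.masked_fill(~valid, -1).to(torch.int32)

In the new code the buffer slice is passed as srcIndices, so the kernel can write its result in place, skipping the full-row argsort and the intermediate masks.
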
@@ -570,48 +572,33 @@ def sparse_attn_indexer_vllm_kunlun(
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
         # padded query len
         current_device = padded_q_fp8_decode_tokens.device
         padded_num_tokens = batch_size * next_n
-        positions = torch.arange(max_model_len,
-                                 device=current_device).unsqueeze(0).expand(
-                                     batch_size * next_n, -1)
-        row_indices = torch.arange(padded_num_tokens,
-                                   device=current_device) // next_n
-        next_n_offset = torch.arange(
-            padded_num_tokens,
-            device=padded_q_fp8_decode_tokens.device) % next_n
-        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
-                         next_n_offset).unsqueeze(1)
-        # index_end_pos: [B * N, 1]
-        mask = positions <= index_end_pos
-        # mask: [B * N, L]
-        logits = logits.masked_fill(~mask, float('-inf'))
-
-        del positions, mask
-        topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens)  # [B * N, K]
-        need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens
-        if need_mask:
-            positions_topk = torch.arange(topk_tokens,
-                                          device=current_device).unsqueeze(0).expand(
-                                              batch_size * next_n, -1)
-            mask_topk = positions_topk <= index_end_pos
-            topk_indices = topk_indices.masked_fill(~mask_topk, -1)  # [B * N, K]
-
-        # ensure we don't set indices for the top k that are out of range
-        # (masked already); this happens when the context length is
-        # shorter than K
-        topk_indices[topk_indices > index_end_pos] = -1
+        num_rows = logits.shape[0]
+        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
+
+        # rowStarts=None and rowEnds=None means this call computes
+        # topk_indices for decode.
+        # Refer to top_k_per_row_decode: https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L643
+        torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
+                                              srcIndices=topk_indices,
+                                              numRows=num_rows,
+                                              stride0=logits.stride(0),
+                                              stride1=logits.stride(1),
+                                              topK=topk_tokens,
+                                              rowStarts=None,
+                                              rowEnds=None,
+                                              seqLens=decode_metadata.seq_lens,
+                                              next_n=next_n)

         if decode_metadata.requires_padding:
             # if padded, we need to unpack
             # the topk indices, removing padded tokens
             topk_indices = unpack_seq_triton(
                 topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                 decode_lens)
-        topk_indices_buffer[:num_decode_tokens, :topk_indices.
-                            shape[-1]] = topk_indices.to(dtype=torch.int32)
+        # return topk_indices_buffer
+        topk_indices_buffer[:num_decode_tokens, :topk_indices.shape[-1]] = (
+            topk_indices
+        )


 def sparse_attn_indexer_vllm_kunlun_fake(
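
Reviewer note: the matching eager sketch for the decode mode (seqLens/next_n set, rowStarts/rowEnds unset), folding together the position masking, the fast_topkv2 call, and the out-of-range fix-ups that the removed code performed as separate steps. topk_per_row_decode_ref is a hypothetical name, and torch.topk stands in for the custom kernel:

import torch

def topk_per_row_decode_ref(logits: torch.Tensor,
                            seq_lens: torch.Tensor,
                            topk_tokens: int,
                            next_n: int) -> torch.Tensor:
    # Row r belongs to request r // next_n at speculative offset r % next_n,
    # so its last valid key position is
    # seq_lens[r // next_n] - next_n + r % next_n (inclusive).
    num_rows, max_len = logits.shape
    rows = torch.arange(num_rows, device=logits.device)
    index_end_pos = (seq_lens[rows // next_n] - next_n +
                     rows % next_n).unsqueeze(1)          # [B * N, 1]
    positions = torch.arange(max_len, device=logits.device).unsqueeze(0)
    logits = logits.masked_fill(positions > index_end_pos, float('-inf'))
    topk_indices = logits.topk(min(topk_tokens, max_len), dim=-1).indices
    # Slots that landed on masked (-inf) positions are out of range,
    # e.g. when a row's valid length is shorter than K; mark them -1.
    topk_indices[topk_indices > index_end_pos] = -1
    return topk_indices.to(torch.int32)
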
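Reviewer note: the unpack step after the kernel call appears unchanged by this commit. Assuming unpack_seq_triton keeps the first decode_lens[b] rows of each padded request and concatenates them, a hypothetical eager equivalent (not the Triton implementation) would be:

import torch

def unpack_seq_ref(packed: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    # packed: [B, N, K] top-k indices padded to N = next_n rows per request;
    # keep the first lens[b] rows of request b -> [sum(lens), K].
    return torch.cat([packed[b, :lens[b]] for b in range(packed.shape[0])],
                     dim=0)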