[Kernel] add topk_per_row to optimize the calculation of topk_indexes (#168)
Signed-off-by: chengxiaokang <chengxiaokang@baidu.com> Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -523,21 +523,23 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
context_k_lens_cpu=chunk.context_k_lens_cpu,
|
context_k_lens_cpu=chunk.context_k_lens_cpu,
|
||||||
)
|
)
|
||||||
del k_fp8, k_scale
|
del k_fp8, k_scale
|
||||||
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
|
|
||||||
|
num_rows = logits.shape[0]
|
||||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
topk_indices = topk_indices_buffer[
|
||||||
mask_lo = topk_indices >= 0
|
chunk.token_start:chunk.token_end, :topk_tokens]
|
||||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
|
||||||
chunk.cu_seqlen_ks)[:, None] < 0
|
# when seqLens=None and next_n=None, it means that it is used to calculate topk_indices in prefill
|
||||||
mask = torch.full_like(topk_indices,
|
# refer to top_k_per_row_prefill:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L698
|
||||||
False,
|
torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
|
||||||
dtype=torch.bool,
|
srcIndices=topk_indices,
|
||||||
device=topk_indices.device)
|
numRows=num_rows,
|
||||||
mask = mask_lo & mask_hi
|
stride0=logits.stride(0),
|
||||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
stride1=logits.stride(1),
|
||||||
topk_indices_buffer[
|
topK=topk_tokens,
|
||||||
chunk.token_start:chunk.token_end, :topk_indices.
|
rowStarts=chunk.cu_seqlen_ks,
|
||||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
rowEnds=chunk.cu_seqlen_ke,
|
||||||
|
seqLens=None,
|
||||||
|
next_n=None)
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_metadata = attn_metadata.decode
|
decode_metadata = attn_metadata.decode
|
||||||
@@ -570,48 +572,33 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
decode_metadata.schedule_metadata,
|
decode_metadata.schedule_metadata,
|
||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
)
|
)
|
||||||
# padded query len
|
|
||||||
current_device = padded_q_fp8_decode_tokens.device
|
|
||||||
padded_num_tokens = batch_size * next_n
|
|
||||||
positions = torch.arange(max_model_len,
|
|
||||||
device=current_device).unsqueeze(0).expand(
|
|
||||||
batch_size * next_n, -1)
|
|
||||||
row_indices = torch.arange(padded_num_tokens,
|
|
||||||
device=current_device) // next_n
|
|
||||||
next_n_offset = torch.arange(
|
|
||||||
padded_num_tokens,
|
|
||||||
device=padded_q_fp8_decode_tokens.device) % next_n
|
|
||||||
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
|
|
||||||
next_n_offset).unsqueeze(1)
|
|
||||||
# index_end_pos: [B * N, 1]
|
|
||||||
mask = positions <= index_end_pos
|
|
||||||
# mask: [B * N, L]
|
|
||||||
logits = logits.masked_fill(~mask, float('-inf'))
|
|
||||||
|
|
||||||
del positions, mask
|
num_rows = logits.shape[0]
|
||||||
topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K]
|
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
||||||
need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens
|
|
||||||
if need_mask:
|
# when row_starts=None and row_ends=None, it means that it is used to calculate topk_indices in decode
|
||||||
positions_topk = torch.arange(topk_tokens,
|
# refer to top_k_per_row_decode:https://github.com/vllm-project/vllm/blob/6a09612b2e0e09d037a220ea8115632b8084e008/csrc/sampler.cu#L643
|
||||||
device=current_device).unsqueeze(0).expand(
|
torch.ops.xspeedgate_ops.topk_per_row(logits=logits,
|
||||||
batch_size * next_n, -1)
|
srcIndices=topk_indices,
|
||||||
mask_topk = positions_topk <= index_end_pos
|
numRows=num_rows,
|
||||||
topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K]
|
stride0=logits.stride(0),
|
||||||
|
stride1=logits.stride(1),
|
||||||
# ensure we don't set indices for the top k
|
topK=topk_tokens,
|
||||||
# that is out of range(masked already)
|
rowStarts=None,
|
||||||
# this will happen if context length is shorter than K
|
rowEnds=None,
|
||||||
topk_indices[topk_indices > index_end_pos] = -1
|
seqLens=decode_metadata.seq_lens,
|
||||||
|
next_n=next_n)
|
||||||
|
|
||||||
|
|
||||||
if decode_metadata.requires_padding:
|
if decode_metadata.requires_padding:
|
||||||
# if padded, we need to unpack
|
# if padded, we need to unpack
|
||||||
# the topk indices removing padded tokens
|
# the topk indices removing padded tokens
|
||||||
topk_indices = unpack_seq_triton(
|
topk_indices = unpack_seq_triton(
|
||||||
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
||||||
decode_lens)
|
decode_lens)
|
||||||
topk_indices_buffer[:num_decode_tokens, :topk_indices.
|
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
|
||||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
topk_indices
|
||||||
|
)
|
||||||
# return topk_indices_buffer
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_attn_indexer_vllm_kunlun_fake(
|
def sparse_attn_indexer_vllm_kunlun_fake(
|
||||||
|
|||||||
Reference in New Issue
Block a user