Add kernels to optimize RoPE and the decoding stage (#143)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-23 10:29:52 +08:00
committed by GitHub
parent 9e13f23661
commit 0ce5f1a3f7
5 changed files with 74 additions and 115 deletions

View File

@@ -442,98 +442,6 @@ class DeepseekV2Attention(nn.Module):
output, _ = self.o_proj(attn_output)
return output
@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
kv_cache, # [num_blocks, block_size, head_dim + 1]
block_table, # [batch_size, num_blocks]
cu_seq_lens, # [batch_size + 1, ]
batch_size,
head_dim,
):
num_blocks, block_size, _ = kv_cache.shape
kv_cache = kv_cache.view(num_blocks, -1)
expected_value = []
expected_scale = []
for b in range(batch_size):
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
if s == 0:
continue
tot = cdiv(s, block_size)
blocks = block_table[b, :tot]
value = []
scale = []
full_block = torch.arange(tot - 1,
device=kv_cache.device,
dtype=torch.int32)
non_remaining_value = kv_cache[blocks[full_block], :block_size *
head_dim].view(-1, head_dim)
non_remaining_scale = kv_cache[blocks[full_block],
block_size * head_dim:].view(-1, 4)
remaining = s - (tot - 1) * block_size
value = torch.cat([
non_remaining_value,
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
],
dim=0)
scale = torch.cat([
non_remaining_scale,
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
remaining * 4].view(-1, 4)
],
dim=0)
expected_value.append(value)
expected_scale.append(scale)
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
gather_value = gather_value.view(torch.int8)
gather_scale = gather_scale.view(torch.float32)
return gather_value, gather_scale
@torch.inference_mode()
def kunlun_indexer_k_quant_cache(
k, #[num_tokens, head_dim]
kv_cache, # [num_blocks, cache_block_size, head_dim + 1]
slot_mapping, # [num_tokens]
quant_block_size,
):
num_blocks, cache_block_size, cache_stride = kv_cache.shape
# num_tokens, head_dim = k.shape
head_dim = k.shape[1]
num_tokens = slot_mapping.shape[0]
assert head_dim % quant_block_size == 0
kv_cache = kv_cache.view(num_blocks, -1)
k_fp8 = torch.empty(
k.shape,
device=k.device,
dtype=torch.int8,
)
k_scale = torch.empty(
[k.shape[0], 1],
device=k.device,
dtype=torch.float32,
)
torch.ops._C.quant2d(k, k_fp8, k_scale, force_sdnn=True)
k_scale /= 127
for token_idx in range(num_tokens):
slot_idx = slot_mapping[token_idx]
if slot_idx < 0:
continue
block_idx = slot_idx // cache_block_size
block_offset = slot_idx % cache_block_size
v_offset = block_offset * head_dim
kv_cache[block_idx, v_offset:v_offset + head_dim] = k_fp8[token_idx, :].view(torch.uint8).contiguous()
s_offset = cache_block_size * head_dim + block_offset * 4
kv_cache[block_idx, s_offset:s_offset + 4] = k_scale[token_idx, :].view(torch.uint8).contiguous()
kv_cache = kv_cache.view(num_blocks, cache_block_size, cache_stride)
@custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=())
def sparse_attn_indexer_vllm_kunlun(
hidden_states: torch.Tensor,
@@ -578,12 +486,6 @@ def sparse_attn_indexer_vllm_kunlun(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
# kunlun_indexer_k_quant_cache(
# k,
# kv_cache,
# slot_mapping,
# quant_block_size,
# )
torch.ops.xspeedgate_ops.indexer_k_quant_and_cache(
k,
@@ -685,8 +587,16 @@ def sparse_attn_indexer_vllm_kunlun(
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = torch.argsort(logits, dim=-1,
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
del positions, mask
topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K]
need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens
if need_mask:
positions_topk = torch.arange(topk_tokens,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
mask_topk = positions_topk <= index_end_pos
topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)