Add kernels to optimize RoPE and the decoding stage (#143)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-23 10:29:52 +08:00
committed by GitHub
parent 9e13f23661
commit 0ce5f1a3f7
5 changed files with 74 additions and 115 deletions

View File

@@ -87,15 +87,13 @@ def int8_paged_mqa_logits(
batch_size, next_n, _, D = q_fp8.shape
num_blocks, block_size, _, _ = kv_cache_fp8.shape
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
k_val = k_val.view(-1,block_size, 1, D)
k_scale_list = []
for block_tables_idx in range(block_tables.shape[0]):
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
D:].view(-1, 4)
k_scale_list.append(k_scale_item)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
k_val = k_val.view(-1, block_size, 1, D)
block_indices = block_tables.flatten()
k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32)
k_scale = k_scale.view(-1, max_model_len)
kv_cache = [k_val, k_scale]
weights = weights.view(batch_size,next_n,-1)