Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -87,15 +87,13 @@ def int8_paged_mqa_logits(
|
||||
batch_size, next_n, _, D = q_fp8.shape
|
||||
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
||||
|
||||
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
|
||||
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
|
||||
k_val = k_val.view(-1,block_size, 1, D)
|
||||
k_scale_list = []
|
||||
for block_tables_idx in range(block_tables.shape[0]):
|
||||
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
|
||||
D:].view(-1, 4)
|
||||
k_scale_list.append(k_scale_item)
|
||||
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
|
||||
kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
|
||||
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
|
||||
k_val = k_val.view(-1, block_size, 1, D)
|
||||
|
||||
block_indices = block_tables.flatten()
|
||||
k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32)
|
||||
k_scale = k_scale.view(-1, max_model_len)
|
||||
kv_cache = [k_val, k_scale]
|
||||
|
||||
weights = weights.view(batch_size,next_n,-1)
|
||||
|
||||
Reference in New Issue
Block a user