Add kernels to optimize RoPE and the decoding stage (#143)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-23 10:29:52 +08:00
committed by GitHub
parent 9e13f23661
commit 0ce5f1a3f7
5 changed files with 74 additions and 115 deletions

View File

@@ -159,10 +159,8 @@ def kunlun_flash_mla_with_kvcache(
assert not causal, \
"causal must be `false` if sparse attention is enabled."
q_r, pe_cache = None, None # 当q_r和pe_cache为空时为packed模式
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
kv_lora_rank = head_dim_v
rope_head_dim = head_dim - kv_lora_rank
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
dtype=q.dtype, device=q.device)

View File

@@ -87,15 +87,13 @@ def int8_paged_mqa_logits(
batch_size, next_n, _, D = q_fp8.shape
num_blocks, block_size, _, _ = kv_cache_fp8.shape
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
k_val = k_val.view(-1,block_size, 1, D)
k_scale_list = []
for block_tables_idx in range(block_tables.shape[0]):
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
D:].view(-1, 4)
k_scale_list.append(k_scale_item)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
k_val = k_val.view(-1, block_size, 1, D)
block_indices = block_tables.flatten()
k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32)
k_scale = k_scale.view(-1, max_model_len)
kv_cache = [k_val, k_scale]
weights = weights.view(batch_size,next_n,-1)

View File

@@ -76,6 +76,24 @@ def vllm_kunlun_forward_cuda(
self.cos_sin_cache, self.is_neox_style)
return query, key
def vllm_ds_rope_forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return torch.ops.xspeedgate_ops.flashinfer_rotary_embedding(
positions=positions,
rotary_dim=self.rotary_dim,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox_style=self.is_neox_style,
query=query,
key=key,
offsets=offsets,
)
def apply_interleaved_rope(x: torch.Tensor,
mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
@@ -145,12 +163,10 @@ def vllm_kunlun_mrope_forward_cuda(
return query, key
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
DeepseekScalingRotaryEmbedding.forward = vllm_ds_rope_forward_cuda
DeepseekScalingRotaryEmbedding.forward_cuda = vllm_ds_rope_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda