Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -159,10 +159,8 @@ def kunlun_flash_mla_with_kvcache(
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
|
||||
q_r, pe_cache = None, None # packed mode is used when q_r and pe_cache are None
|
||||
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
||||
kv_lora_rank = head_dim_v
|
||||
rope_head_dim = head_dim - kv_lora_rank
|
||||
|
||||
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
|
||||
dtype=q.dtype, device=q.device)
|
||||
|
||||
@@ -87,15 +87,13 @@ def int8_paged_mqa_logits(
|
||||
batch_size, next_n, _, D = q_fp8.shape
|
||||
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
||||
|
||||
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
|
||||
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
|
||||
k_val = k_val.view(-1,block_size, 1, D)
|
||||
k_scale_list = []
|
||||
for block_tables_idx in range(block_tables.shape[0]):
|
||||
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
|
||||
D:].view(-1, 4)
|
||||
k_scale_list.append(k_scale_item)
|
||||
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
|
||||
kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
|
||||
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
|
||||
k_val = k_val.view(-1, block_size, 1, D)
|
||||
|
||||
block_indices = block_tables.flatten()
|
||||
k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32)
|
||||
k_scale = k_scale.view(-1, max_model_len)
|
||||
kv_cache = [k_val, k_scale]
|
||||
|
||||
weights = weights.view(batch_size,next_n,-1)
|
||||
|
||||
@@ -76,6 +76,24 @@ def vllm_kunlun_forward_cuda(
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def vllm_ds_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Delegate rotary embedding to the ``xspeedgate_ops`` custom kernel.

    The call is forwarded unchanged to
    ``torch.ops.xspeedgate_ops.flashinfer_rotary_embedding`` together with
    the rotary configuration stored on ``self`` (``rotary_dim``,
    ``head_size``, ``cos_sin_cache`` and ``is_neox_style``).

    Args:
        positions: Position tensor handed to the kernel as-is.
        query: Query tensor handed to the kernel as-is.
        key: Optional key tensor; ``None`` is passed through untouched.
        offsets: Optional offsets tensor; ``None`` is passed through
            untouched.

    Returns:
        Whatever the kernel returns, with no post-processing.
        NOTE(review): annotated as a ``(query, key)`` tuple -- confirm
        against the op's registered schema.
    """
    # Resolve the custom op once, then invoke it with keyword arguments
    # only, so the call does not depend on the op's positional order.
    rope_kernel = torch.ops.xspeedgate_ops.flashinfer_rotary_embedding
    rotated = rope_kernel(
        positions=positions,
        rotary_dim=self.rotary_dim,
        head_size=self.head_size,
        cos_sin_cache=self.cos_sin_cache,
        is_neox_style=self.is_neox_style,
        query=query,
        key=key,
        offsets=offsets,
    )
    return rotated
|
||||
|
||||
def apply_interleaved_rope(x: torch.Tensor,
|
||||
mrope_section: list[int]) -> torch.Tensor:
|
||||
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||
@@ -145,12 +163,10 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
|
||||
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = vllm_ds_rope_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = vllm_ds_rope_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
|
||||
|
||||
Reference in New Issue
Block a user