Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -76,6 +76,24 @@ def vllm_kunlun_forward_cuda(
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def vllm_ds_rope_forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return torch.ops.xspeedgate_ops.flashinfer_rotary_embedding(
|
||||
positions=positions,
|
||||
rotary_dim=self.rotary_dim,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=self.cos_sin_cache,
|
||||
is_neox_style=self.is_neox_style,
|
||||
query=query,
|
||||
key=key,
|
||||
offsets=offsets,
|
||||
)
|
||||
|
||||
def apply_interleaved_rope(x: torch.Tensor,
|
||||
mrope_section: list[int]) -> torch.Tensor:
|
||||
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||
@@ -145,12 +163,10 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
|
||||
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = vllm_ds_rope_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = vllm_ds_rope_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
|
||||
|
||||
Reference in New Issue
Block a user