diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 4e44d8cf..0e43a598 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -14,7 +14,8 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream - +from typing import Optional, Tuple +from vllm_ascend.ops.triton.rope import rope_forward_triton def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: @@ -302,7 +303,15 @@ def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset: torch.Tensor) -> torch.Tensor: return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) - +def _rope_forward_triton_fake( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + rope_dim: int = -1, + is_neox_style: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) direct_register_custom_op(op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, @@ -369,3 +378,8 @@ direct_register_custom_op(op_name="quantize", fake_impl=_quantize_impl_fake, mutates_args=[], dispatch_key="PrivateUse1") +direct_register_custom_op(op_name="rope_forward_triton", + op_func=rope_forward_triton, + fake_impl=_rope_forward_triton_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index d699ec7d..afc01d55 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -187,6 +187,7 @@ def _rope_forward_oot( self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + cos, sin = get_cos_and_sin_slice() # adopt custom kernel path for rotary_embedding if _custom_rotary_embedding_enabled( query, is_neox_style, self.head_size) and get_ascend_device_type( @@ -204,7 +205,6 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - cos, sin = get_cos_and_sin_slice() if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ -1] == 128 and cos is not None and sin is not None: # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. @@ -217,28 +217,43 @@ def _rope_forward_oot( query, key = torch_npu.npu_apply_rotary_pos_emb( query, key, cos, sin) elif self.rotary_dim < self.head_size: - num_tokens = query.shape[0] - query = query.view(num_tokens, -1, self.head_size) - key = key.view(num_tokens, -1, self.head_size) - q_rot = query[..., :self.rotary_dim] - q_pass = query[..., self.rotary_dim:] - k_rot = key[..., :self.rotary_dim] - k_pass = key[..., self.rotary_dim:] - q_rot = q_rot.contiguous().view(num_tokens, -1) - k_rot = k_rot.contiguous().view(num_tokens, -1) - torch_npu._npu_rotary_embedding( - positions, - q_rot, - k_rot, - self.head_size, - self.cos_sin_cache, - is_neox_style, - ) - q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) - k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) - q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) - k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) - return q, k + if HAS_TRITON: + + cos = cos.view(-1, self.rotary_dim) + sin = sin.view(-1, self.rotary_dim) + q = query.contiguous().view(query.shape[0], -1, + self.head_size) + k = key.contiguous().view(key.shape[0], -1, self.head_size) + query, key = torch.ops.vllm.rope_forward_triton(q, + k, + cos, + sin, + rope_dim=self.rotary_dim, + is_neox_style=True) + return query.view(query_shape), key.view(key_shape) + else: + num_tokens = query.shape[0] + query = query.view(num_tokens, -1, self.head_size) + key = key.view(num_tokens, -1, self.head_size) + q_rot = query[..., :self.rotary_dim] + q_pass = query[..., self.rotary_dim:] + k_rot = key[..., :self.rotary_dim] + k_pass = key[..., self.rotary_dim:] + q_rot = q_rot.contiguous().view(num_tokens, -1) + k_rot = k_rot.contiguous().view(num_tokens, -1) + torch_npu._npu_rotary_embedding( + positions, + q_rot, + k_rot, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) + k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) + q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) + return q, k else: # TODO: Remove the contiguous in the future. query = query.contiguous().view(query.shape[0], -1) diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index 3700e329..8eecac8f 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -15,7 +15,8 @@ # This file is a part of the vllm-ascend project. # from vllm.triton_utils import tl, triton - +import torch +from typing import Tuple from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num @@ -157,12 +158,14 @@ def _triton_rope( mask=second_k_mask) -def rope_forward_triton(q, - k, - cos, - sin, +def rope_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, rope_dim: int = -1, - is_neox_style: bool = True): + is_neox_style: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: if not q.is_contiguous(): q = q.contiguous() if not k.is_contiguous():