From 66a0837963ff5dd6734907083ca3cf57e6bb223b Mon Sep 17 00:00:00 2001 From: Pleaplusone <38376071+ganyi1996ppo@users.noreply.github.com> Date: Fri, 18 Apr 2025 08:56:05 +0800 Subject: [PATCH] adopt rope in vllm-ascend (#530) ### What this PR does / why we need it? Adopt custom kernel rotary embedding in actual model inference, customized rotary_embedding will generate contiguous query and key in the cpp side to reduce the overhead of two contiguous and index_select compared with rotary_embedding in torch_npu. For now, rotary_embedding can only support the scenario of `is_neox = true`, non-neox version rope will be updated soon in the future. --------- Signed-off-by: ganyi --- csrc/kernels/pos_encoding_kernels.cpp | 4 +-- csrc/torch_binding.cpp | 30 ++++++++++++-------- tests/ops/test_rotary_embedding.py | 6 ++-- vllm_ascend/ops/rotary_embedding.py | 40 ++++++--------------------- vllm_ascend/platform.py | 6 ++-- 5 files changed, 37 insertions(+), 49 deletions(-) diff --git a/csrc/kernels/pos_encoding_kernels.cpp b/csrc/kernels/pos_encoding_kernels.cpp index cce08ca..28ef503 100644 --- a/csrc/kernels/pos_encoding_kernels.cpp +++ b/csrc/kernels/pos_encoding_kernels.cpp @@ -28,9 +28,9 @@ using vllm_ascend::AccType; using vllm_ascend::local_mem_copy; template class RotaryEmbedding { - // NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to + // NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to // retrive this size from runtime for more Soc support - static int constexpr loadSize = 1024 * 4; + static int constexpr loadSize = 512; using dst_t = scalar_t; using acc_t = typename AccType::type; // only half tensor have cast instruct to int8, hardcode acc_dst_t as half diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index a4dc3a3..b874a43 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -29,7 +29,7 @@ namespace vllm_ascend { -void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, +std::tuple rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) { int32_t deviceId = 0; @@ -51,44 +51,52 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, key.size(1) == positions.size(1), "query, key and positions must have the same batch_size and seq_len"); } - + TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32"); int query_hidden_size = query.numel() / num_tokens; int key_hidden_size = key.numel() / num_tokens; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); + TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend"); // Make sure query and key have consistent number of heads int num_heads = query_hidden_size / head_size; int num_kv_heads = key_hidden_size / head_size; TORCH_CHECK(num_heads % num_kv_heads == 0); + at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options()); + at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options()); int rot_dim = cos_sin_cache.size(1); + int seq_dim_idx = positions_ndim - 1; int64_t *position_ids_ptr = positions.data_ptr(); + void *query_dst_ptr = query_dst.data_ptr(); + void *key_dst_ptr = key_dst.data_ptr(); void *query_ptr = query.data_ptr(); void *key_ptr = key.data_ptr(); void *cos_sin_cache_ptr = cos_sin_cache.data_ptr(); - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); + int64_t dst_query_stride = query_dst.stride(0); + int64_t dst_key_stride = key_dst.stride(0); at::ScalarType scalar_type = query.scalar_type(); aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("rotary_embedding"); - cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, + cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, - num_heads, num_kv_heads, head_size]() -> int { + dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int { auto dtype_num = get_dtype_from_torch(scalar_type); fe::PlatFormInfos platform_infos; int device_id = 0; fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos); uint32_t aivNum = platform_infos.GetCoreNumByType("aiv"); uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum; - rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr, - key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride, - key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum); + rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr, + key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride, + dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum); return 0; }); cmd.Run(); - return ; + return {query_dst, key_dst}; } } // namespace vllm_ascend @@ -101,7 +109,7 @@ TORCH_LIBRARY_EXPAND(_C, ops) ops.def( "rotary_embedding(Tensor positions, Tensor! query," " Tensor! key, int head_size," - " Tensor cos_sin_cache, bool is_neox) -> ()"); + " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)"); ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding); } diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py index 3e48fe1..800960b 100644 --- a/tests/ops/test_rotary_embedding.py +++ b/tests/ops/test_rotary_embedding.py @@ -184,7 +184,7 @@ def test_rotary_embedding_quant_with_leading_dim( ) ref_query, ref_key = rope.forward_native(positions, query, key) - torch.ops._C.rotary_embedding( + query, key = torch.ops._C.rotary_embedding( positions, query, key, @@ -194,11 +194,11 @@ def test_rotary_embedding_quant_with_leading_dim( ) # Compare the results. - torch.testing.assert_close(query, + torch.testing.assert_close(query.view(ref_query.size()), ref_query, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL) - torch.testing.assert_close(key, + torch.testing.assert_close(key.view(ref_key.size()), ref_key, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 32e660f..5fc30ee 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -21,6 +21,8 @@ import torch from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm_ascend.platform import CUSTOM_OP_ENABLED + def rope_forward_oot( self, @@ -35,14 +37,9 @@ def rope_forward_oot( self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - if offsets is not None: - raise NotImplementedError( - "Batched rotary embedding is currently not supported on NPU.") - else: - # TODO: Remove the contiguous in the future. - query = query.contiguous() - key = key.contiguous() - torch_npu._npu_rotary_embedding( + # adopt custom kernel path for rotary_embedding + if CUSTOM_OP_ENABLED and self.is_neox_style: + return torch.ops._C.rotary_embedding( positions, query, key, @@ -50,30 +47,14 @@ def rope_forward_oot( self.cos_sin_cache, self.is_neox_style, ) - return query, key - - -def rope_deepseek_forward_oot( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - import torch_npu - - if self.cos_sin_cache.device != query.device: - self.cos_sin_cache = self.cos_sin_cache.to(query.device) - if self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: # TODO: Remove the contiguous in the future. - ori_query_shape, ori_key_shape = query.shape, key.shape + query_shape, key_shape = query.shape, key.shape query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) torch_npu._npu_rotary_embedding( positions, query, @@ -82,11 +63,8 @@ def rope_deepseek_forward_oot( self.cos_sin_cache, self.is_neox_style, ) - query = query.view(ori_query_shape) - key = key.view(ori_key_shape) - - return query, key + return query.view(query_shape), key.view(key_shape) RotaryEmbedding.forward_oot = rope_forward_oot -DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot +DeepseekScalingRotaryEmbedding.forward = rope_forward_oot diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index ff35808..5a09905 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -23,7 +23,9 @@ import torch import torch_npu # noqa: F401 import vllm.envs as envs from vllm.logger import logger +from vllm.platforms import Platform, PlatformEnum +CUSTOM_OP_ENABLED = False try: # register custom ops into torch_library here import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 @@ -35,8 +37,8 @@ except ImportError as e: logging.warning( "Warning: Failed to register custom ops, all custom ops will be disabled" ) - -from vllm.platforms import Platform, PlatformEnum + else: + CUSTOM_OP_ENABLED = True if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig