adopt rope in vllm-ascend (#530)

### What this PR does / why we need it?
Adopt custom kernel rotary embedding in actual model inference,
customized rotary_embedding will generate contiguous query and key in
the cpp side to reduce the overhead of two contiguous and index_select
compared with rotary_embedding in torch_npu. For now, rotary_embedding
can only support the scenario of `is_neox = true`, non-neox version rope
will be updated soon in the future.
---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-04-18 08:56:05 +08:00
committed by GitHub
parent 23f85e3f74
commit 66a0837963
5 changed files with 37 additions and 49 deletions

View File

@@ -28,9 +28,9 @@
using vllm_ascend::AccType;
using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
// NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
// retrive this size from runtime for more Soc support
static int constexpr loadSize = 1024 * 4;
static int constexpr loadSize = 512;
using dst_t = scalar_t;
using acc_t = typename AccType<scalar_t>::type;
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half

View File

@@ -29,7 +29,7 @@
namespace vllm_ascend {
void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
int32_t deviceId = 0;
@@ -51,44 +51,52 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1;
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
void *query_dst_ptr = query_dst.data_ptr();
void *key_dst_ptr = key_dst.data_ptr();
void *query_ptr = query.data_ptr();
void *key_ptr = key.data_ptr();
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
int64_t dst_query_stride = query_dst.stride(0);
int64_t dst_key_stride = key_dst.stride(0);
at::ScalarType scalar_type = query.scalar_type();
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("rotary_embedding");
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size]() -> int {
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
auto dtype_num = get_dtype_from_torch(scalar_type);
fe::PlatFormInfos platform_infos;
int device_id = 0;
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr,
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
return 0;
});
cmd.Run();
return ;
return {query_dst, key_dst};
}
} // namespace vllm_ascend
@@ -101,7 +109,7 @@ TORCH_LIBRARY_EXPAND(_C, ops)
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
}

View File

@@ -184,7 +184,7 @@ def test_rotary_embedding_quant_with_leading_dim(
)
ref_query, ref_key = rope.forward_native(positions, query, key)
torch.ops._C.rotary_embedding(
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
@@ -194,11 +194,11 @@ def test_rotary_embedding_quant_with_leading_dim(
)
# Compare the results.
torch.testing.assert_close(query,
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key,
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)

View File

@@ -21,6 +21,8 @@ import torch
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.platform import CUSTOM_OP_ENABLED
def rope_forward_oot(
self,
@@ -35,14 +37,9 @@ def rope_forward_oot(
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous()
key = key.contiguous()
torch_npu._npu_rotary_embedding(
# adopt custom kernel path for rotary_embedding
if CUSTOM_OP_ENABLED and self.is_neox_style:
return torch.ops._C.rotary_embedding(
positions,
query,
key,
@@ -50,30 +47,14 @@ def rope_forward_oot(
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def rope_deepseek_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
ori_query_shape, ori_key_shape = query.shape, key.shape
query_shape, key_shape = query.shape, key.shape
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
@@ -82,11 +63,8 @@ def rope_deepseek_forward_oot(
self.cos_sin_cache,
self.is_neox_style,
)
query = query.view(ori_query_shape)
key = key.view(ori_key_shape)
return query, key
return query.view(query_shape), key.view(key_shape)
RotaryEmbedding.forward_oot = rope_forward_oot
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
DeepseekScalingRotaryEmbedding.forward = rope_forward_oot

View File

@@ -23,7 +23,9 @@ import torch
import torch_npu # noqa: F401
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
CUSTOM_OP_ENABLED = False
try:
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
@@ -35,8 +37,8 @@ except ImportError as e:
logging.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
from vllm.platforms import Platform, PlatformEnum
else:
CUSTOM_OP_ENABLED = True
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig