adopt rope in vllm-ascend (#530)
### What this PR does / why we need it? Adopt custom kernel rotary embedding in actual model inference, customized rotary_embedding will generate contiguous query and key in the cpp side to reduce the overhead of two contiguous and index_select compared with rotary_embedding in torch_npu. For now, rotary_embedding can only support the scenario of `is_neox = true`, non-neox version rope will be updated soon in the future. --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -28,9 +28,9 @@
|
||||
using vllm_ascend::AccType;
|
||||
using vllm_ascend::local_mem_copy;
|
||||
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
||||
// NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to
|
||||
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
|
||||
// retrive this size from runtime for more Soc support
|
||||
static int constexpr loadSize = 1024 * 4;
|
||||
static int constexpr loadSize = 512;
|
||||
using dst_t = scalar_t;
|
||||
using acc_t = typename AccType<scalar_t>::type;
|
||||
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
||||
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
||||
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
|
||||
{
|
||||
int32_t deviceId = 0;
|
||||
@@ -51,44 +51,52 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
||||
key.size(1) == positions.size(1),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.numel() / num_tokens;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
|
||||
|
||||
// Make sure query and key have consistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key_hidden_size / head_size;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
|
||||
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
|
||||
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
|
||||
void *query_dst_ptr = query_dst.data_ptr();
|
||||
void *key_dst_ptr = key_dst.data_ptr();
|
||||
void *query_ptr = query.data_ptr();
|
||||
void *key_ptr = key.data_ptr();
|
||||
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.stride(seq_dim_idx);
|
||||
int64_t dst_query_stride = query_dst.stride(0);
|
||||
int64_t dst_key_stride = key_dst.stride(0);
|
||||
at::ScalarType scalar_type = query.scalar_type();
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("rotary_embedding");
|
||||
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
|
||||
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
|
||||
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
|
||||
num_heads, num_kv_heads, head_size]() -> int {
|
||||
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
|
||||
auto dtype_num = get_dtype_from_torch(scalar_type);
|
||||
fe::PlatFormInfos platform_infos;
|
||||
int device_id = 0;
|
||||
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
|
||||
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
|
||||
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
|
||||
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr,
|
||||
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
|
||||
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
|
||||
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
|
||||
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
|
||||
return 0;
|
||||
});
|
||||
cmd.Run();
|
||||
return ;
|
||||
return {query_dst, key_dst};
|
||||
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -101,7 +109,7 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
|
||||
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
||||
}
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ def test_rotary_embedding_quant_with_leading_dim(
|
||||
)
|
||||
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
torch.ops._C.rotary_embedding(
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
@@ -194,11 +194,11 @@ def test_rotary_embedding_quant_with_leading_dim(
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(query,
|
||||
torch.testing.assert_close(query.view(ref_query.size()),
|
||||
ref_query,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
torch.testing.assert_close(key,
|
||||
torch.testing.assert_close(key.view(ref_key.size()),
|
||||
ref_key,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
@@ -21,6 +21,8 @@ import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
@@ -35,14 +37,9 @@ def rope_forward_oot(
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
torch_npu._npu_rotary_embedding(
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style:
|
||||
return torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
@@ -50,30 +47,14 @@ def rope_forward_oot(
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
|
||||
def rope_deepseek_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
import torch_npu
|
||||
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
ori_query_shape, ori_key_shape = query.shape, key.shape
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
@@ -82,11 +63,8 @@ def rope_deepseek_forward_oot(
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
query = query.view(ori_query_shape)
|
||||
key = key.view(ori_key_shape)
|
||||
|
||||
return query, key
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = rope_forward_oot
|
||||
|
||||
@@ -23,7 +23,9 @@ import torch
|
||||
import torch_npu # noqa: F401
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
CUSTOM_OP_ENABLED = False
|
||||
try:
|
||||
# register custom ops into torch_library here
|
||||
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
||||
@@ -35,8 +37,8 @@ except ImportError as e:
|
||||
logging.warning(
|
||||
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
||||
)
|
||||
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
else:
|
||||
CUSTOM_OP_ENABLED = True
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
Reference in New Issue
Block a user