adopt rope in vllm-ascend (#530)
### What this PR does / why we need it? Adopt custom kernel rotary embedding in actual model inference, customized rotary_embedding will generate contiguous query and key in the cpp side to reduce the overhead of two contiguous and index_select compared with rotary_embedding in torch_npu. For now, rotary_embedding can only support the scenario of `is_neox = true`, non-neox version rope will be updated soon in the future. --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -28,9 +28,9 @@
|
|||||||
using vllm_ascend::AccType;
|
using vllm_ascend::AccType;
|
||||||
using vllm_ascend::local_mem_copy;
|
using vllm_ascend::local_mem_copy;
|
||||||
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
||||||
// NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to
|
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
|
||||||
// retrive this size from runtime for more Soc support
|
// retrive this size from runtime for more Soc support
|
||||||
static int constexpr loadSize = 1024 * 4;
|
static int constexpr loadSize = 512;
|
||||||
using dst_t = scalar_t;
|
using dst_t = scalar_t;
|
||||||
using acc_t = typename AccType<scalar_t>::type;
|
using acc_t = typename AccType<scalar_t>::type;
|
||||||
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
|
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
|
||||||
|
|||||||
@@ -29,7 +29,7 @@
|
|||||||
|
|
||||||
namespace vllm_ascend {
|
namespace vllm_ascend {
|
||||||
|
|
||||||
void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
||||||
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
|
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
|
||||||
{
|
{
|
||||||
int32_t deviceId = 0;
|
int32_t deviceId = 0;
|
||||||
@@ -51,44 +51,52 @@ void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
|||||||
key.size(1) == positions.size(1),
|
key.size(1) == positions.size(1),
|
||||||
"query, key and positions must have the same batch_size and seq_len");
|
"query, key and positions must have the same batch_size and seq_len");
|
||||||
}
|
}
|
||||||
|
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
|
||||||
int query_hidden_size = query.numel() / num_tokens;
|
int query_hidden_size = query.numel() / num_tokens;
|
||||||
int key_hidden_size = key.numel() / num_tokens;
|
int key_hidden_size = key.numel() / num_tokens;
|
||||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||||
|
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
|
||||||
|
|
||||||
// Make sure query and key have consistent number of heads
|
// Make sure query and key have consistent number of heads
|
||||||
int num_heads = query_hidden_size / head_size;
|
int num_heads = query_hidden_size / head_size;
|
||||||
int num_kv_heads = key_hidden_size / head_size;
|
int num_kv_heads = key_hidden_size / head_size;
|
||||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||||
|
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
|
||||||
|
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
|
||||||
|
|
||||||
int rot_dim = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
|
int seq_dim_idx = positions_ndim - 1;
|
||||||
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
|
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
|
||||||
|
void *query_dst_ptr = query_dst.data_ptr();
|
||||||
|
void *key_dst_ptr = key_dst.data_ptr();
|
||||||
void *query_ptr = query.data_ptr();
|
void *query_ptr = query.data_ptr();
|
||||||
void *key_ptr = key.data_ptr();
|
void *key_ptr = key.data_ptr();
|
||||||
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
|
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
|
||||||
int64_t query_stride = query.stride(-2);
|
int64_t query_stride = query.stride(seq_dim_idx);
|
||||||
int64_t key_stride = key.stride(-2);
|
int64_t key_stride = key.stride(seq_dim_idx);
|
||||||
|
int64_t dst_query_stride = query_dst.stride(0);
|
||||||
|
int64_t dst_key_stride = key_dst.stride(0);
|
||||||
at::ScalarType scalar_type = query.scalar_type();
|
at::ScalarType scalar_type = query.scalar_type();
|
||||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||||
at_npu::native::OpCommand cmd;
|
at_npu::native::OpCommand cmd;
|
||||||
cmd.Name("rotary_embedding");
|
cmd.Name("rotary_embedding");
|
||||||
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
|
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
|
||||||
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
|
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
|
||||||
num_heads, num_kv_heads, head_size]() -> int {
|
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
|
||||||
auto dtype_num = get_dtype_from_torch(scalar_type);
|
auto dtype_num = get_dtype_from_torch(scalar_type);
|
||||||
fe::PlatFormInfos platform_infos;
|
fe::PlatFormInfos platform_infos;
|
||||||
int device_id = 0;
|
int device_id = 0;
|
||||||
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
|
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
|
||||||
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
|
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
|
||||||
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
|
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
|
||||||
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr,
|
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
|
||||||
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride,
|
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
|
||||||
key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
|
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
|
||||||
return 0;
|
return 0;
|
||||||
});
|
});
|
||||||
cmd.Run();
|
cmd.Run();
|
||||||
return ;
|
return {query_dst, key_dst};
|
||||||
}
|
}
|
||||||
} // namespace vllm_ascend
|
} // namespace vllm_ascend
|
||||||
|
|
||||||
@@ -101,7 +109,7 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
|||||||
ops.def(
|
ops.def(
|
||||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||||
" Tensor! key, int head_size,"
|
" Tensor! key, int head_size,"
|
||||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
|
||||||
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ def test_rotary_embedding_quant_with_leading_dim(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||||
torch.ops._C.rotary_embedding(
|
query, key = torch.ops._C.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -194,11 +194,11 @@ def test_rotary_embedding_quant_with_leading_dim(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
torch.testing.assert_close(query,
|
torch.testing.assert_close(query.view(ref_query.size()),
|
||||||
ref_query,
|
ref_query,
|
||||||
atol=DEFAULT_ATOL,
|
atol=DEFAULT_ATOL,
|
||||||
rtol=DEFAULT_RTOL)
|
rtol=DEFAULT_RTOL)
|
||||||
torch.testing.assert_close(key,
|
torch.testing.assert_close(key.view(ref_key.size()),
|
||||||
ref_key,
|
ref_key,
|
||||||
atol=DEFAULT_ATOL,
|
atol=DEFAULT_ATOL,
|
||||||
rtol=DEFAULT_RTOL)
|
rtol=DEFAULT_RTOL)
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ import torch
|
|||||||
from vllm.model_executor.layers.rotary_embedding import (
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||||
|
|
||||||
|
from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
||||||
|
|
||||||
|
|
||||||
def rope_forward_oot(
|
def rope_forward_oot(
|
||||||
self,
|
self,
|
||||||
@@ -35,14 +37,9 @@ def rope_forward_oot(
|
|||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||||
if self.cos_sin_cache.dtype != query.dtype:
|
if self.cos_sin_cache.dtype != query.dtype:
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||||
if offsets is not None:
|
# adopt custom kernel path for rotary_embedding
|
||||||
raise NotImplementedError(
|
if CUSTOM_OP_ENABLED and self.is_neox_style:
|
||||||
"Batched rotary embedding is currently not supported on NPU.")
|
return torch.ops._C.rotary_embedding(
|
||||||
else:
|
|
||||||
# TODO: Remove the contiguous in the future.
|
|
||||||
query = query.contiguous()
|
|
||||||
key = key.contiguous()
|
|
||||||
torch_npu._npu_rotary_embedding(
|
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -50,30 +47,14 @@ def rope_forward_oot(
|
|||||||
self.cos_sin_cache,
|
self.cos_sin_cache,
|
||||||
self.is_neox_style,
|
self.is_neox_style,
|
||||||
)
|
)
|
||||||
return query, key
|
|
||||||
|
|
||||||
|
|
||||||
def rope_deepseek_forward_oot(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
if self.cos_sin_cache.device != query.device:
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
|
||||||
if self.cos_sin_cache.dtype != query.dtype:
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
|
||||||
if offsets is not None:
|
if offsets is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Batched rotary embedding is currently not supported on NPU.")
|
"Batched rotary embedding is currently not supported on NPU.")
|
||||||
else:
|
else:
|
||||||
# TODO: Remove the contiguous in the future.
|
# TODO: Remove the contiguous in the future.
|
||||||
ori_query_shape, ori_key_shape = query.shape, key.shape
|
query_shape, key_shape = query.shape, key.shape
|
||||||
query = query.contiguous().view(query.shape[0], -1)
|
query = query.contiguous().view(query.shape[0], -1)
|
||||||
key = key.contiguous().view(query.shape[0], -1)
|
key = key.contiguous().view(key.shape[0], -1)
|
||||||
torch_npu._npu_rotary_embedding(
|
torch_npu._npu_rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
@@ -82,11 +63,8 @@ def rope_deepseek_forward_oot(
|
|||||||
self.cos_sin_cache,
|
self.cos_sin_cache,
|
||||||
self.is_neox_style,
|
self.is_neox_style,
|
||||||
)
|
)
|
||||||
query = query.view(ori_query_shape)
|
return query.view(query_shape), key.view(key_shape)
|
||||||
key = key.view(ori_key_shape)
|
|
||||||
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
|
|
||||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||||
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
DeepseekScalingRotaryEmbedding.forward = rope_forward_oot
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ import torch
|
|||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
from vllm.platforms import Platform, PlatformEnum
|
||||||
|
|
||||||
|
CUSTOM_OP_ENABLED = False
|
||||||
try:
|
try:
|
||||||
# register custom ops into torch_library here
|
# register custom ops into torch_library here
|
||||||
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
||||||
@@ -35,8 +37,8 @@ except ImportError as e:
|
|||||||
logging.warning(
|
logging.warning(
|
||||||
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
from vllm.platforms import Platform, PlatformEnum
|
CUSTOM_OP_ENABLED = True
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
|
|||||||
Reference in New Issue
Block a user