adopt rope in vllm-ascend (#530)
### What this PR does / why we need it? Adopt the custom rotary embedding kernel in actual model inference. The customized rotary_embedding generates contiguous query and key on the C++ side, which removes the overhead of the two `contiguous()` calls and the `index_select` required by the torch_npu rotary_embedding. For now, the custom rotary_embedding only supports the `is_neox = true` scenario; the non-neox rope variant will be added in a follow-up. --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
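For context, here is a minimal usage sketch (not part of this change) of how the patched `forward_oot` is exercised through vLLM's `RotaryEmbedding` layer. It assumes an Ascend environment with `torch_npu` and `vllm-ascend` installed and the plugin's op patches already imported; the constructor arguments, shapes, and the `vllm_ascend.ops` import path are illustrative assumptions, not part of this PR.

```python
import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

import vllm_ascend.ops  # noqa: F401  # assumed module that applies the forward_oot patches below
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size, num_heads, num_tokens = 128, 8, 16

# Illustrative rope config; is_neox_style=True is required for the custom kernel path.
rope = RotaryEmbedding(head_size, head_size, 4096, 10000, True,
                       torch.float16).to("npu")

positions = torch.arange(num_tokens, device="npu")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="npu")
key = torch.randn(num_tokens, num_heads * head_size,
                  dtype=torch.float16, device="npu")

# With CUSTOM_OP_ENABLED this goes through torch.ops._C.rotary_embedding;
# otherwise it falls back to torch_npu._npu_rotary_embedding plus the
# extra contiguous copies described above.
query, key = rope.forward_oot(positions, query, key)
print(query.shape, key.shape)
```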
@@ -21,6 +21,8 @@ import torch
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
+from vllm_ascend.platform import CUSTOM_OP_ENABLED
+
 
 def rope_forward_oot(
     self,
@@ -35,14 +37,9 @@ def rope_forward_oot(
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
     if self.cos_sin_cache.dtype != query.dtype:
         self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
-    if offsets is not None:
-        raise NotImplementedError(
-            "Batched rotary embedding is currently not supported on NPU.")
-    else:
-        # TODO: Remove the contiguous in the future.
-        query = query.contiguous()
-        key = key.contiguous()
-        torch_npu._npu_rotary_embedding(
+    # adopt custom kernel path for rotary_embedding
+    if CUSTOM_OP_ENABLED and self.is_neox_style:
+        return torch.ops._C.rotary_embedding(
             positions,
             query,
             key,
@@ -50,30 +47,14 @@ def rope_forward_oot(
             self.cos_sin_cache,
             self.is_neox_style,
         )
-    return query, key
-
-
-def rope_deepseek_forward_oot(
-    self,
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    import torch_npu
-
-    if self.cos_sin_cache.device != query.device:
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
-    if self.cos_sin_cache.dtype != query.dtype:
-        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
     if offsets is not None:
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
         # TODO: Remove the contiguous in the future.
-        ori_query_shape, ori_key_shape = query.shape, key.shape
+        query_shape, key_shape = query.shape, key.shape
         query = query.contiguous().view(query.shape[0], -1)
-        key = key.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(key.shape[0], -1)
         torch_npu._npu_rotary_embedding(
             positions,
             query,
@@ -82,11 +63,8 @@ def rope_deepseek_forward_oot(
             self.cos_sin_cache,
             self.is_neox_style,
         )
-        query = query.view(ori_query_shape)
-        key = key.view(ori_key_shape)
-
-    return query, key
+    return query.view(query_shape), key.view(key_shape)
 
 
 RotaryEmbedding.forward_oot = rope_forward_oot
-DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
+DeepseekScalingRotaryEmbedding.forward = rope_forward_oot