[Ops][Refactor] Remove custom rotary_embedding operator (#6523)
### What this PR does / why we need it?
This PR removes the custom `rotary_embedding` operator and its associated C++ kernel implementation, PyTorch bindings, and tests. The codebase now falls back to the native `torch_npu._npu_rotary_embedding` implementation. This change simplifies the codebase by removing custom, platform-specific kernel code and relying on the standard NPU library implementation, which is presumably more optimized and easier to maintain.

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring and does not introduce any user-facing changes.

### How was this patch tested?
The tests for the custom `rotary_embedding` operator have been removed along with the operator itself. The correctness of the fallback to the native `torch_npu` implementation is verified by existing CI tests for attention layers and models that use rotary embeddings.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
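For quick reference, below is a minimal sketch of the fallback path this PR settles on, condensed from the full-`rotary_dim` branch of `_rope_forward_oot` in the diff further down. The helper name `apply_rope_fallback` is illustrative only and not part of the codebase; the argument order and in-place behaviour of `torch_npu._npu_rotary_embedding` follow what the diff shows.

```python
import torch
import torch_npu


def apply_rope_fallback(positions: torch.Tensor,
                        query: torch.Tensor,
                        key: torch.Tensor,
                        head_size: int,
                        cos_sin_cache: torch.Tensor,
                        is_neox_style: bool):
    # Flatten to (num_tokens, num_heads * head_size), the layout the NPU kernel expects.
    query = query.contiguous().view(query.shape[0], -1)
    key = key.contiguous().view(key.shape[0], -1)
    # The native torch_npu kernel rotates query/key in place; this replaces the
    # removed torch.ops._C_ascend.rotary_embedding custom op.
    torch_npu._npu_rotary_embedding(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
    )
    return query, key
```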
@@ -33,7 +33,7 @@ if HAS_TRITON:
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, enable_custom_op, get_ascend_device_type, has_rope, is_vl_model
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model

# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -144,10 +144,6 @@ def get_cos_and_sin_slice():
    return _cos_slice, _sin_slice


def _custom_rotary_embedding_enabled(query, neox_style, head_size):
    return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op()


def _rope_forward_oot(
    self,
    positions: torch.Tensor,
@@ -162,9 +158,62 @@ def _rope_forward_oot(
    if self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
    cos, sin = get_cos_and_sin_slice()
    # adopt custom kernel path for rotary_embedding
    if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size):
        query, key = torch.ops._C_ascend.rotary_embedding(
    if offsets is not None:
        raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
    if (
        is_neox_style
        and self.head_size == 128
        and self.cos_sin_cache.shape[-1] == 128
        and cos is not None
        and sin is not None
    ):
        # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
        # This method requires head_size and rotary_dim equal 128 and neox_style is True
        query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
        key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
        # Although this function modifies in-place, please retain the function's return value.
        # Otherwise, the graph fusion operation may fail.
        query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
    elif self.rotary_dim < self.head_size:
        if HAS_TRITON:
            cos = cos.view(-1, self.rotary_dim)
            sin = sin.view(-1, self.rotary_dim)
            q = query.contiguous().view(query.shape[0], -1, self.head_size)
            k = key.contiguous().view(key.shape[0], -1, self.head_size)
            query, key = torch.ops.vllm.rope_forward_triton(
                q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
            )
            return query.view(query_shape), key.view(key_shape)
        else:
            num_tokens = query.shape[0]
            query = query.view(num_tokens, -1, self.head_size)
            key = key.view(num_tokens, -1, self.head_size)
            q_rot = query[..., : self.rotary_dim]
            q_pass = query[..., self.rotary_dim :]
            k_rot = key[..., : self.rotary_dim]
            k_pass = key[..., self.rotary_dim :]
            q_rot = q_rot.contiguous().view(num_tokens, -1)
            k_rot = k_rot.contiguous().view(num_tokens, -1)
            # only the rotary part is processed here,
            # the dimension should be rotary_dim
            torch_npu._npu_rotary_embedding(
                positions,
                q_rot,
                k_rot,
                self.rotary_dim,
                self.cos_sin_cache,
                is_neox_style,
            )
            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
            return q, k
    else:
        # TODO: Remove the contiguous in the future.
        query = query.contiguous().view(query.shape[0], -1)
        key = key.contiguous().view(key.shape[0], -1)
        torch_npu._npu_rotary_embedding(
            positions,
            query,
            key,
@@ -172,72 +221,7 @@ def _rope_forward_oot(
            self.cos_sin_cache,
            is_neox_style,
        )
        return query.view(query_shape), key.view(key_shape)
    if offsets is not None:
        raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
    else:
        if (
            is_neox_style
            and self.head_size == 128
            and self.cos_sin_cache.shape[-1] == 128
            and cos is not None
            and sin is not None
        ):
            # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
            # This method requires head_size and rotary_dim equal 128 and neox_style is True
            query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
            key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
            # Although this function modifies in-place, please retain the function's return value.
            # Otherwise, the graph fusion operation may fail.
            query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
        elif self.rotary_dim < self.head_size:
            if HAS_TRITON:
                cos = cos.view(-1, self.rotary_dim)
                sin = sin.view(-1, self.rotary_dim)
                q = query.contiguous().view(query.shape[0], -1, self.head_size)
                k = key.contiguous().view(key.shape[0], -1, self.head_size)
                query, key = torch.ops.vllm.rope_forward_triton(
                    q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
                )
                return query.view(query_shape), key.view(key_shape)
            else:
                num_tokens = query.shape[0]
                query = query.view(num_tokens, -1, self.head_size)
                key = key.view(num_tokens, -1, self.head_size)
                q_rot = query[..., : self.rotary_dim]
                q_pass = query[..., self.rotary_dim :]
                k_rot = key[..., : self.rotary_dim]
                k_pass = key[..., self.rotary_dim :]
                q_rot = q_rot.contiguous().view(num_tokens, -1)
                k_rot = k_rot.contiguous().view(num_tokens, -1)
                # only the rotary part is processed here,
                # the dimension should be rotary_dim
                torch_npu._npu_rotary_embedding(
                    positions,
                    q_rot,
                    k_rot,
                    self.rotary_dim,
                    self.cos_sin_cache,
                    is_neox_style,
                )
                q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
                k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
                q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
                k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
                return q, k
        else:
            # TODO: Remove the contiguous in the future.
            query = query.contiguous().view(query.shape[0], -1)
            key = key.contiguous().view(key.shape[0], -1)
            torch_npu._npu_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                is_neox_style,
            )
            return query.view(query_shape), key.view(key_shape)
    return query.view(query_shape), key.view(key_shape)


class AscendRotaryEmbedding(RotaryEmbedding):