[Ops][Refactor] Remove custom rotary_embedding operator (#6523)
### What this PR does / why we need it?
This PR removes the custom `rotary_embedding` operator and its associated C++ kernel implementation, PyTorch bindings, and tests. The codebase now falls back to using the native `torch_npu._npu_rotary_embedding` implementation. This change simplifies the codebase by removing custom, platform-specific kernel code and relying on the standard NPU library implementation, which is presumably more optimized and easier to maintain.

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring and does not introduce any user-facing changes.

### How was this patch tested?
The tests for the custom `rotary_embedding` operator have been removed along with the operator itself. The correctness of the fallback to the native `torch_npu` implementation is verified by existing CI tests for attention layers and models that use rotary embeddings.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
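For reference, here is a minimal sketch of what the Python-side fallback might look like. The argument order is assumed to mirror the removed custom op (`positions, query, key, head_size, cos_sin_cache, is_neox`, taken from the deleted meta implementation in the diff below); the wrapper function name is hypothetical, and the exact `torch_npu._npu_rotary_embedding` signature and in-place semantics are assumptions, not taken from this PR.

```python
# Illustrative sketch only: the wrapper name and the exact torch_npu signature
# are assumed, based on the removed custom op's argument list.
import torch
import torch_npu


def apply_rotary_embedding_native(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    # Instead of torch.ops._C_ascend.rotary_embedding(...), dispatch to the op
    # shipped with torch_npu; query and key are assumed to be rotated in place.
    torch_npu._npu_rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox
    )
```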
@@ -31,7 +31,7 @@ from torch.library import Library
 # 3. The registration utility will check if a meta implementation already exists for your op,
 # and only register if necessary. This avoids duplicate registrations.
 #
-# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
+# 4. Example meta implementations are provided below for get_masked_input_and_mask.
 #
 # 5. When developing new custom ops, always provide a meta implementation to enable tracing,
 # export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
@@ -52,25 +52,6 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
     lib.impl(op_name, fn, "Meta")
 
 
-def rotary_embedding_meta(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-):
-    num_tokens = positions.numel()
-    query_hidden_size = query.numel() // num_tokens
-    key_hidden_size = key.numel() // num_tokens
-    num_heads = query_hidden_size // head_size
-    num_kv_heads = key_hidden_size // head_size
-
-    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
-    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
-    return query_dst, key_dst
-
-
 def get_masked_input_and_mask_meta(
     input: torch.Tensor,
     org_vocab_start_index: int,
@@ -105,7 +86,6 @@ def sgmv_expand_meta(
     return y_out
 
 
-register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
 register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
 register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
 register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)
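The file comments in the first hunk describe the convention this registry follows: each custom op in the `_C_ascend` namespace gets a meta (shape-only) implementation so tracing, export, and `torch.compile` graph capture can infer output shapes without dispatching to the NPU kernel. Below is a hedged sketch of that pattern for a hypothetical new op, reusing the `register_meta_if_necessary` helper shown in the diff; the op name and output shape are invented for illustration and do not exist in vllm-ascend.

```python
import torch

# Hypothetical example following the pattern above; "my_new_op" is not a real
# vllm-ascend op. A meta function only computes output shapes and dtypes on the
# meta device, so graphs can be traced without touching NPU memory.
def my_new_op_meta(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Output keeps the leading dims of x and takes the output feature dim of weight.
    return x.new_empty(*x.shape[:-1], weight.shape[0])


# register_meta_if_necessary is the helper defined in this file; it skips the
# registration if a meta kernel for this op is already present.
register_meta_if_necessary("_C_ascend", "my_new_op", my_new_op_meta)
```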