[Ops][Refactor] Remove custom rotary_embedding operator (#6523)

### What this PR does / why we need it?
This PR removes the custom `rotary_embedding` operator and its
associated C++ kernel implementation, PyTorch bindings, and tests.

The codebase now falls back to using the native
`torch_npu._npu_rotary_embedding` implementation. This change simplifies
the codebase by removing custom, platform-specific kernel code and
relying on the standard NPU library implementation, which is presumably
more optimized and easier to maintain.

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring and does not introduce any
user-facing changes.

### How was this patch tested?
The tests for the custom `rotary_embedding` operator have been removed
along with the operator itself. The correctness of the fallback to the
native `torch_npu` implementation is verified by existing CI tests for
attention layers and models that use rotary embeddings.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2026-02-07 09:24:05 +08:00
committed by GitHub
parent 06aa6036f6
commit 6c49f95da2
8 changed files with 59 additions and 1392 deletions

View File

@@ -31,7 +31,7 @@ from torch.library import Library
# 3. The registration utility will check if a meta implementation already exists for your op,
# and only register if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
# 4. Example meta implementations are provided below for get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation to enable tracing,
# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
@@ -52,25 +52,6 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
lib.impl(op_name, fn, "Meta")
def rotary_embedding_meta(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
):
num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens
num_heads = query_hidden_size // head_size
num_kv_heads = key_hidden_size // head_size
query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
return query_dst, key_dst
def get_masked_input_and_mask_meta(
input: torch.Tensor,
org_vocab_start_index: int,
@@ -105,7 +86,6 @@ def sgmv_expand_meta(
return y_out
register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)