[CustomOp] Register AscendApplyRotaryEmb CustomOp and remove related patch (#4667)

### What this PR does / why we need it?

Following https://github.com/vllm-project/vllm/pull/29873, register
`AscendApplyRotaryEmb` CustomOp and remove related patch.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

####  Test Qwen2.5-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-b02c1ff3415d2462","object":"chat.completion","created":1766129265,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-In struct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is writ  ten in blue, and \"Qwen\" is written in gray. The text appears to be part of a logo or branding design.","refusal":null,"annotations":null,"audio":   null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"tok    en_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":129,"completion_tokens":51,"prompt_tokens_d
```

####  Test Qwen3-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-a3a7de5a900a9321","object":"chat.completion","created":1766129586,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-12-23 10:04:37 +08:00
committed by GitHub
parent 35dbdbb398
commit 6c478531f8
7 changed files with 71 additions and 260 deletions

View File

@@ -18,12 +18,14 @@
import math
from typing import Optional, Tuple
import einops
import torch
import torch_npu
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
@@ -524,3 +526,59 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
rotary_mode='half')
return query, key
class AscendApplyRotaryEmb(ApplyRotaryEmb):
def __init__(
self,
enforce_enable: bool = False,
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> None:
super().__init__(
enforce_enable=enforce_enable,
is_neox_style=is_neox_style,
enable_fp32_compute=enable_fp32_compute,
)
def forward_oot(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
head_dim = x.shape[-1]
origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
# cos, sin: [seq_len, head_dim // 2]
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
# cos, sin: [1, seq_len, 1, head_dim]
cos = cos.reshape(1, -1, 1, head_dim)
sin = sin.reshape(1, -1, 1, head_dim)
if len(x.shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
# x: [1, seq_len, num_heads, head_size]
output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
else:
assert len(x.shape) == 4
# x: [2 * b, s, head, head_dim]
qk = einops.rearrange(
x, "(two b) s head head_dim -> b s two head head_dim", two=2)
# q, k: [b, s, head, head_dim]
q, k = qk[:, :, 0], qk[:, :, 1]
q = torch_npu.npu_rotary_mul(q, cos, sin)
k = torch_npu.npu_rotary_mul(k, cos, sin)
output = torch.cat([q, k], dim=0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
return output