From d350c2ada6845894a9c58a63d2d3fa27713ce4a9 Mon Sep 17 00:00:00 2001
From: Shanshan Shen <467638484@qq.com>
Date: Tue, 13 Jan 2026 15:47:23 +0800
Subject: [PATCH] [CustomOp][Perf] Merge Q/K split to simplify AscendApplyRotaryEmb for better performance (#5799)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?

- Use the upstream util functions (`_pre_process()` and `_post_process()`) to reduce redundant code; a sketch of the normalize/restore pattern they implement appears just before the diff. (See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py#L184-L213 for details.)
- Merge the Q/K split into a single `torch_npu.npu_rotary_mul()` call, which simplifies `AscendApplyRotaryEmb` and improves performance (mean TPOT is reduced by **6.22%** in the benchmark below). A plain-PyTorch sketch of why the merged call is equivalent follows the functional test.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

#### ✅ Functional test

Launch the server:

```bash
export VLLM_USE_MODELSCOPE=True
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
  --dtype bfloat16 \
  --limit-mm-per-prompt '{"image": 1}' \
  --max-model-len 16384 \
  --max-num-batched-tokens 16384
```

Query the server:

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png"}},
        {"type": "text", "text": "What is the text in the illustrate? How does it look?"}
      ]}
    ],
    "max_tokens": 100
  }'
```

Output:

```
{"id":"chatcmpl-b2911ab6989ef098","object":"chat.completion","created":1768202780,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The text appears to be part of a logo or branding design, with \"TONGYI\" being more prominent and \"Qwen\" being slightly smaller and positioned below it. The font style is modern and clean, with \"TONGYI\" having a slightly bolder appearance compared to \"Qwen.\"","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":178,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
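The merge is safe because `cos`/`sin` are shaped `[1, seq_len, 1, head_dim]` and broadcast over the batch dimension, so rotating the stacked `[2 * b, s, head, head_dim]` tensor in one call matches rotating `q` and `k` separately. A minimal sketch of that equivalence, using a plain-PyTorch `rotary_ref` as a hypothetical stand-in for `torch_npu.npu_rotary_mul` (the stand-in and all shapes are illustrative, not the NPU kernel):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half used by rotary embeddings.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_ref(x, cos, sin):
    # Stand-in for torch_npu.npu_rotary_mul: cos/sin ([1, s, 1, head_dim])
    # broadcast over the batch and head dimensions of x.
    return x * cos + rotate_half(x) * sin


b, s, head, head_dim = 2, 4, 8, 64
q = torch.randn(b, s, head, head_dim)
k = torch.randn(b, s, head, head_dim)
cos = torch.randn(1, s, 1, head_dim)
sin = torch.randn(1, s, 1, head_dim)

# Old path: two kernel launches, one for Q and one for K.
split = torch.cat([rotary_ref(q, cos, sin), rotary_ref(k, cos, sin)], dim=0)
# New path: one launch over the stacked [2 * b, s, head, head_dim] tensor.
merged = rotary_ref(torch.cat([q, k], dim=0), cos, sin)

assert torch.allclose(split, merged)
```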
#### ✅ Benchmark

Run:

```bash
export VLLM_USE_MODELSCOPE=False
export HF_ENDPOINT="https://hf-mirror.com"
vllm bench serve \
  --model /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
  --backend openai-chat \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
  --hf-split train \
  --dataset-path lmarena-ai/vision-arena-bench-v0.1 \
  --num-prompts 10 \
  --no-stream
```

Before this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     10
Failed requests:                         0
Benchmark duration (s):                  5.96
Total input tokens:                      7191
Total generated tokens:                  996
Request throughput (req/s):              1.68
Output token throughput (tok/s):         167.05
Peak output token throughput (tok/s):    261.00
Peak concurrent requests:                10.00
Total token throughput (tok/s):          1373.16
---------------Time to First Token----------------
Mean TTFT (ms):                          964.43
Median TTFT (ms):                        858.48
P99 TTFT (ms):                           1691.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          63.08
Median TPOT (ms):                        40.86
P99 TPOT (ms):                           241.30
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.16
Median ITL (ms):                         33.61
P99 ITL (ms):                            250.30
==================================================
```

After this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     10
Failed requests:                         0
Benchmark duration (s):                  5.71
Total input tokens:                      7191
Total generated tokens:                  996
Request throughput (req/s):              1.75
Output token throughput (tok/s):         174.45
Peak output token throughput (tok/s):    279.00
Peak concurrent requests:                10.00
Total token throughput (tok/s):          1433.95
---------------Time to First Token----------------
Mean TTFT (ms):                          992.14
Median TTFT (ms):                        938.30
P99 TTFT (ms):                           1728.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          59.16
Median TPOT (ms):                        37.65
P99 TPOT (ms):                           234.89
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.55
Median ITL (ms):                         30.73
P99 ITL (ms):                            170.72
==================================================
```

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm_ascend/ops/rotary_embedding.py | 50 ++++++++++++++---------------
 1 file changed, 24 insertions(+), 26 deletions(-)
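The diff below is easier to read with the shape/dtype normalization pattern in mind: pre-processing lifts 3-D inputs to 4-D and optionally upcasts to fp32, and post-processing undoes both. The helpers sketched here are illustrative stand-ins reconstructed from the `vllm_version_is('0.13.0')` fallback branch in this patch, not the upstream vLLM implementations (see the `common.py` link above for those):

```python
import torch


def pre_process(x, cos, sin, enable_fp32_compute=True):
    # Normalize to a 4-D (optionally fp32) tensor, remembering how to undo it.
    origin_shape, origin_dtype = x.shape, x.dtype
    if len(origin_shape) == 3:
        # [seq_len, num_heads, head_dim] -> [1, seq_len, num_heads, head_dim]
        x = x.unsqueeze(0)
    if enable_fp32_compute:
        x, cos, sin = x.float(), cos.float(), sin.float()
    return x, cos, sin, origin_shape, origin_dtype


def post_process(output, origin_shape, origin_dtype):
    # Undo the normalization: drop the added batch dim, restore the dtype.
    if len(origin_shape) == 3:
        output = output.squeeze(0)
    return output.to(origin_dtype)
```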
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 7fba3b06..d699ec7d 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -18,7 +18,6 @@
 import math
 from typing import Optional, Tuple
 
-import einops
 import torch
 import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
@@ -32,7 +31,8 @@ if HAS_TRITON:
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
-                               get_ascend_device_type, has_rope, is_vl_model)
+                               get_ascend_device_type, has_rope, is_vl_model,
+                               vllm_version_is)
 
 # Currently, rope ops used on npu requires detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -622,14 +622,20 @@ class AscendApplyRotaryEmb(ApplyRotaryEmb):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ) -> torch.Tensor:
+        if vllm_version_is('0.13.0'):
+            origin_shape = x.shape
+            origin_dtype = x.dtype
+            if len(origin_shape) == 3:
+                x = x.unsqueeze(0)
+            if self.enable_fp32_compute:
+                x = x.float()
+                cos = cos.float()
+                sin = sin.float()
+        else:
+            x, cos, sin, origin_shape, origin_dtype = self._pre_process(
+                x, cos, sin)
+
         head_dim = x.shape[-1]
-
-        origin_dtype = x.dtype
-        if self.enable_fp32_compute:
-            x = x.float()
-            cos = cos.float()
-            sin = sin.float()
-
         # cos, sin: [seq_len, head_dim // 2]
         cos = torch.cat((cos, cos), dim=-1)
         sin = torch.cat((sin, sin), dim=-1)
@@ -637,22 +643,14 @@ class AscendApplyRotaryEmb(ApplyRotaryEmb):
         cos = cos.reshape(1, -1, 1, head_dim)
         sin = sin.reshape(1, -1, 1, head_dim)
 
-        if len(x.shape) == 3:
-            # x: [seq_len, num_heads, head_size]
-            x = x.unsqueeze(0)
-            # x: [1, seq_len, num_heads, head_size]
-            output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
-        else:
-            assert len(x.shape) == 4
-            # x: [2 * b, s, head, head_dim]
-            qk = einops.rearrange(
-                x, "(two b) s head head_dim -> b s two head head_dim", two=2)
-            # q, k: [b, s, head, head_dim]
-            q, k = qk[:, :, 0], qk[:, :, 1]
-            q = torch_npu.npu_rotary_mul(q, cos, sin)
-            k = torch_npu.npu_rotary_mul(k, cos, sin)
-            output = torch.cat([q, k], dim=0)
+        output = torch_npu.npu_rotary_mul(x, cos, sin)
+
+        if vllm_version_is('0.13.0'):
+            if len(origin_shape) == 3:
+                output = output.squeeze(0)
+            if self.enable_fp32_compute:
+                output = output.to(origin_dtype)
+        else:
+            output = self._post_process(output, origin_shape, origin_dtype)
 
-        if self.enable_fp32_compute:
-            output = output.to(origin_dtype)
         return output
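For reference, the `cos`/`sin` duplication and reshape in the patched `forward` is what lets a single broadcasted call cover both input layouts. A quick, illustrative shape walkthrough (not part of the patch; the dimensions are made up):

```python
import torch

seq_len, num_heads, head_dim = 16, 8, 128

# cos/sin arrive as [seq_len, head_dim // 2]; the forward duplicates the
# halves and reshapes to [1, seq_len, 1, head_dim] so they broadcast over
# both the batch and head dimensions.
half = torch.randn(seq_len, head_dim // 2)
cos = torch.cat((half, half), dim=-1).reshape(1, -1, 1, head_dim)
assert cos.shape == (1, seq_len, 1, head_dim)

# Layout 1: a 3-D input, unsqueezed to [1, seq_len, heads, head_dim] by the
# pre-processing step.
x3 = torch.randn(seq_len, num_heads, head_dim).unsqueeze(0)
# Layout 2: stacked Q/K, already 4-D as [2 * b, s, heads, head_dim].
x4 = torch.randn(2 * 2, seq_len, num_heads, head_dim)

# Either way, cos/sin broadcast cleanly against the input.
for x in (x3, x4):
    assert torch.broadcast_shapes(x.shape, cos.shape) == x.shape
```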