xc-llm-ascend/vllm_ascend/ops/mm_encoder_attention.py (180 lines, 6.7 KiB, Python)
[CustomOp] Register AscendMMEncoderAttention CustomOp and remove related patch (#4750)

### What this PR does / why we need it?

Following https://github.com/vllm-project/vllm/pull/30125, register the `AscendMMEncoderAttention` CustomOp and remove the related patch.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

✅ Run Qwen2.5-VL:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
    --max_model_len 16384
```

Output:

```
{"id":"chatcmpl-b4e3053f30ab2442","object":"chat.completion","created":1764922950,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the image is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" being slightly larger than \"Qwen.\" The design includes a geometric, abstract shape on the left side of the logo, which complements the text.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":162,"completion_tokens":84,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```

✅ Run Qwen3-VL:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --max_model_len 16384
```

Output:

```
{"id":"chatcmpl-97571fbda8267bd1","object":"chat.completion","created":1764923306,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
2025-12-22 14:32:53 +08:00
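The CustomOp registration this commit relies on follows a name-keyed registry pattern: a platform plugin registers its implementation under the op's name, and vLLM dispatches to it instead of the default. The sketch below illustrates the idea only; the decorator, registry, and `resolve` helper are hypothetical stand-ins, not vLLM's actual API.

```python
# Minimal sketch of a name-keyed op registry, illustrating how an
# out-of-tree (OOT) platform override can replace a default op.
# All names here are hypothetical, not vLLM's real CustomOp machinery.
_OP_REGISTRY: dict[str, type] = {}


def register_oot(name: str):
    """Decorator that records an out-of-tree override under `name`."""
    def wrap(cls: type) -> type:
        _OP_REGISTRY[name] = cls
        return cls
    return wrap


class MMEncoderAttention:
    """Default implementation the override replaces."""
    def forward(self, q, k, v):
        return "default path"


@register_oot("MMEncoderAttention")
class AscendMMEncoderAttention(MMEncoderAttention):
    """NPU-specific override, picked up via the registry."""
    def forward(self, q, k, v):
        return "npu path"


def resolve(name: str, default: type) -> type:
    # Dispatch to the registered override when one exists.
    return _OP_REGISTRY.get(name, default)
```

With this pattern in place, model code asks the registry for `"MMEncoderAttention"` and transparently gets the Ascend class on NPU deployments, which is what makes the separate monkey-patch removable.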
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import einops
[MM][Perf] Pre-compute `seq_lens` and put it on CPU before ViT vision blocks for better performance (#7104)

### What this PR does / why we need it?

**Background:** PR https://github.com/vllm-project/vllm-ascend/pull/6448 introduced a `seq_lens` CPU cache mechanism, which considerably benefits VL-model performance but may lead to accuracy issues. Thus, we have reverted it.

**Proposed Change:** PR https://github.com/vllm-project/vllm/pull/36605 added support for custom processing logic for OOT MMEncoder kernels in vLLM. Thus, we can pre-compute `seq_lens` (rather than `cu_seqlens`) and put it on the CPU before the ViT vision blocks to avoid redundant computation.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

#### ✅ Functional Test

Run Qwen2.5-VL:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
    --limit-mm-per-prompt '{"image": 1}'
```

Output:

```
"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" having a slightly bolder and more prominent appearance compared to \"Qwen.\" The overall design is simple and professional."
```

> [!NOTE]
> Since PR https://github.com/vllm-project/vllm/pull/36605 only modified `Qwen3-VL` modeling files, this PR has no effect on the `Qwen2.5-VL` model.

---

Run Qwen3-VL:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
    --limit-mm-per-prompt '{"image": 1}'
```

Output:

```
"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with TONG."
```

---

#### ✅ Benchmark

Launch the server:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --dtype bfloat16 \
    --limit-mm-per-prompt '{"image": 1}' \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384
```

Run benchmark:

```bash
vllm bench serve \
    --model /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --hf-split train \
    --dataset-path lmarena-ai/vision-arena-bench-v0.1 \
    --num-prompts 500 \
    --request-rate 10 \
    --burstiness 5 \
    --no-stream
```

Before this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Request rate configured (RPS):           10.00
Benchmark duration (s):                  78.58
Total input tokens:                      33418
Total generated tokens:                  61431
Request throughput (req/s):              6.36
Output token throughput (tok/s):         781.78
Peak output token throughput (tok/s):    2475.00
Peak concurrent requests:                383.00
Total token throughput (tok/s):          1207.07
---------------Time to First Token----------------
Mean TTFT (ms):                          7116.24
Median TTFT (ms):                        4295.84
P99 TTFT (ms):                           18370.87
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          245.78
Median TPOT (ms):                        264.03
P99 TPOT (ms):                           334.38
---------------Inter-token Latency----------------
Mean ITL (ms):                           246.99
Median ITL (ms):                         117.71
P99 ITL (ms):                            1327.55
==================================================
```

After this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Request rate configured (RPS):           10.00
Benchmark duration (s):                  77.44
Total input tokens:                      33418
Total generated tokens:                  61522
Request throughput (req/s):              6.46
Output token throughput (tok/s):         794.40
Peak output token throughput (tok/s):    2691.00
Peak concurrent requests:                369.00
Total token throughput (tok/s):          1225.91
---------------Time to First Token----------------
Mean TTFT (ms):                          6888.64
Median TTFT (ms):                        4128.82
P99 TTFT (ms):                           17487.94
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          240.14
Median TPOT (ms):                        259.18
P99 TPOT (ms):                           313.15
---------------Inter-token Latency----------------
Mean ITL (ms):                           241.84
Median ITL (ms):                         121.08
P99 ITL (ms):                            1470.33
==================================================
```

**Performance Metrics:**

| Metric | Before this PR | After this PR | Comparison |
| :----- | :------------- | :------------ | :--------- |
| **Throughput** | | | |
| Request throughput (req/s) | 6.36 | 6.46 | +1.57% ↑ |
| Output token throughput (tok/s) | 781.78 | 794.40 | +1.61% ↑ |
| Total token throughput (tok/s) | 1,207.07 | 1,225.91 | +1.56% ↑ |
| Peak output token throughput (tok/s) | 2,475 | 2,691 | +8.73% ↑ |
| **Latency** | | | |
| Benchmark duration (s) | 78.58 | 77.44 | -1.45% ↓ |
| Mean TTFT (ms) | 7,116.24 | 6,888.64 | -3.20% ↓ |
| Median TTFT (ms) | 4,295.84 | 4,128.82 | -3.89% ↓ |
| P99 TTFT (ms) | 18,370.87 | 17,487.94 | -4.81% ↓ |
| Mean TPOT (ms) | 245.78 | 240.14 | -2.29% ↓ |
| Median TPOT (ms) | 264.03 | 259.18 | -1.84% ↓ |
| P99 TPOT (ms) | 334.38 | 313.15 | -6.35% ↓ |
| Mean ITL (ms) | 246.99 | 241.84 | -2.09% ↓ |
| Median ITL (ms) | 117.71 | 121.08 | +2.86% ↑ |
| P99 ITL (ms) | 1,327.55 | 1,470.33 | +10.76% ↑ |

**🤖 AI Summary:**

- The most notable improvement is in P99 TPOT, which dropped **-6.35%** from 334.38ms → 313.15ms, indicating reduced tail latency for per-token generation under heavy load.
- TTFT improved across all percentiles: mean dropped **-3.20%** (7,116ms → 6,889ms), median **-3.89%** (4,296ms → 4,129ms), and P99 **-4.81%** (18,371ms → 17,488ms), reflecting faster time-to-first-token across the board.
- TPOT also improved consistently, with mean down **-2.29%** (245.78ms → 240.14ms) and median down **-1.84%** (264.03ms → 259.18ms), showing a modest but steady reduction in per-token generation time.
- Throughput saw a slight uplift of roughly **+1.6%** across request, output token, and total token throughput. Peak output token throughput jumped **+8.73%** (2,475 → 2,691 tok/s), suggesting better burst-handling capacity.
- P99 ITL increased **+10.76%** (1,328ms → 1,470ms), the largest regression in the run. Median ITL also ticked up **+2.86%** (117.71ms → 121.08ms). These tail-latency spikes may reflect scheduling variability under peak concurrency and could be within run-to-run noise, but are worth monitoring.
- Overall, the PR delivers a consistent improvement in both throughput and latency, with the caveat that P99 inter-token latency regressed, likely a transient effect given that mean ITL still improved by **-2.09%**.

---

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: shen-shanshan <467638484@qq.com>
2026-03-23 15:24:26 +08:00
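The pre-computation the commit above describes amounts to deriving per-sequence lengths from the cumulative `cu_seqlens` boundaries once, on the host, before the stack of vision blocks runs, rather than re-deriving them inside every attention layer. A minimal pure-Python sketch of that transformation (the helper name is illustrative; the real code operates on tensors):

```python
def seq_lens_from_cu_seqlens(cu_seqlens: list[int]) -> list[int]:
    """Turn cumulative boundaries [0, n1, n1+n2, ...] into [n1, n2, ...]."""
    return [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]


# Computed once on the CPU before the ViT vision blocks; every attention
# layer then reuses this host-side list instead of diffing cu_seqlens
# (and copying device-to-host) once per layer.
cu_seqlens = [0, 4, 10, 13]
seq_lens = seq_lens_from_cu_seqlens(cu_seqlens)  # [4, 6, 3]
```

Since the derivation runs once per forward step instead of once per layer, the per-layer host overhead that previously stalled the device is removed.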
import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention # type: ignore
from vllm.v1.attention.backends.registry import AttentionBackendEnum
[MM][Perf] Use `seq_lens` CPU cache to avoid frequent d2h copy for better performance (#6448)

### What this PR does / why we need it?

Currently, the performance of multi-modal encoding (i.e., the `AscendMMEncoderAttention` forward) is considerably bounded by heavy host-side pre-processing. The profiling results below show that, before the actual attention computation, there are long idle periods on the device, leading to extremely low NPU utilization.

<img width="2264" height="1398" alt="iShot_2026-01-23_16 26 39" src="https://github.com/user-attachments/assets/37f21d06-e526-4f28-82fe-005746cf13bd" />

---

**To optimize this, this PR proposes four changes:**

1. Use a `seq_lens` CPU cache to avoid frequent d2h copies. Before this PR, `AscendMMEncoderAttention` copied `cu_seqlens` from NPU to CPU in every forward pass, since the op `_npu_flash_attention_unpad()` requires `cu_seqlens` on CPU (otherwise it crashes). Thus, we use `seq_lens_cpu_cache` to cache this tensor: it is shared across all layers but may change between forward steps. When the current `layer_index` is `0`, we update the cache; otherwise we use the cache directly, avoiding frequent `diff` and `copy` operations, which are costly.
2. Pre-compute the scale value to avoid recalculating it in every forward pass.
3. Move the `enable_pad` check from forward to the `__init__` method.
4. Revert https://github.com/vllm-project/vllm-ascend/pull/6204.

**Performance after these optimizations:**

- **TTFT** has been reduced by **7.43%** ⬇️.
- **Throughput** has been increased by **1.23%** ⬆️.

---

> [!NOTE]
> This PR requires https://github.com/vllm-project/vllm/pull/33674 to be merged.

---

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Launch the server:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --dtype bfloat16 \
    --limit-mm-per-prompt '{"image": 1}' \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
    --no-async-scheduling
```

Run benchmark:

```bash
vllm bench serve \
    --model /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --hf-split train \
    --dataset-path lmarena-ai/vision-arena-bench-v0.1 \
    --num-prompts 500 \
    --request-rate 10 \
    --burstiness 5 \
    --no-stream
```

Before this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Request rate configured (RPS):           10.00
Benchmark duration (s):                  82.23
Total input tokens:                      33418
Total generated tokens:                  61543
Request throughput (req/s):              6.08
Output token throughput (tok/s):         748.45
Peak output token throughput (tok/s):    3203.00
Peak concurrent requests:                402.00
Total token throughput (tok/s):          1154.86
---------------Time to First Token----------------
Mean TTFT (ms):                          10275.37
Median TTFT (ms):                        6297.88
P99 TTFT (ms):                           22918.26
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          263.02
Median TPOT (ms):                        277.61
P99 TPOT (ms):                           483.56
---------------Inter-token Latency----------------
Mean ITL (ms):                           257.31
Median ITL (ms):                         94.83
P99 ITL (ms):                            1773.90
==================================================
```

After this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Request rate configured (RPS):           10.00
Benchmark duration (s):                  81.20
Total input tokens:                      33418
Total generated tokens:                  61509
Request throughput (req/s):              6.16
Output token throughput (tok/s):         757.54
Peak output token throughput (tok/s):    2562.00
Peak concurrent requests:                395.00
Total token throughput (tok/s):          1169.11
---------------Time to First Token----------------
Mean TTFT (ms):                          9511.91
Median TTFT (ms):                        5479.78
P99 TTFT (ms):                           21427.21
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          261.12
Median TPOT (ms):                        276.03
P99 TPOT (ms):                           446.99
---------------Inter-token Latency----------------
Mean ITL (ms):                           254.04
Median ITL (ms):                         97.71
P99 ITL (ms):                            1516.67
==================================================
```

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd

Signed-off-by: shen-shanshan <467638484@qq.com>
2026-02-26 08:49:36 +08:00
MIN_PAD_SIZE: int = 64  # minimum size for weight padding
MAX_PAD_SIZE: int = 128  # maximum size for weight padding
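The padding logic that consumes these bounds is not shown in this excerpt, but a plausible reading, sketched for illustration only (the helper and the exact rule are assumptions), is that a head size falling strictly between the two bounds is padded up to `MAX_PAD_SIZE` so the NPU kernel sees an aligned size:

```python
# Redeclared here so the sketch is self-contained.
MIN_PAD_SIZE = 64
MAX_PAD_SIZE = 128


def padded_head_size(head_size: int) -> int:
    """Illustrative guess at the padding rule: sizes strictly between the
    bounds are padded up to MAX_PAD_SIZE; sizes at or below MIN_PAD_SIZE,
    or already at MAX_PAD_SIZE, are left unchanged."""
    if MIN_PAD_SIZE < head_size < MAX_PAD_SIZE:
        return MAX_PAD_SIZE
    return head_size
```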
# Use a seq_lens CPU cache to avoid frequent d2h copies.
# AscendMMEncoderAttention copies cu_seqlens from NPU to CPU in every
# forward pass, since the op _npu_flash_attention_unpad() requires
# cu_seqlens on CPU (otherwise it will crash).
# Thus, we use seq_lens_cpu_cache to cache this tensor: it is shared
# across all layers but may change between forward steps. When the
# current layer_index is 0, we update the cache; otherwise we use the
# cached value directly, avoiding frequent diff and copy operations,
# which are costly.
seq_lens_cpu_cache: torch.Tensor | None = None
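The caching policy the comment above describes can be sketched as follows. This is a simplified pure-Python model with a hypothetical helper name: the real code works on NPU tensors, where the diff plus device-to-host copy is the expensive part being amortized.

```python
_seq_lens_cpu_cache = None  # shared across layers, refreshed each forward step


def get_seq_lens_cpu(cu_seqlens: list[int], layer_index: int) -> list[int]:
    """Update the host-side cache only in layer 0; later layers reuse it.

    `cu_seqlens` stands in for the device tensor; here it is a plain list,
    so the 'd2h copy + diff' is modeled by the list comprehension."""
    global _seq_lens_cpu_cache
    if layer_index == 0 or _seq_lens_cpu_cache is None:
        # The costly work happens once per step, in the first layer.
        _seq_lens_cpu_cache = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
    return _seq_lens_cpu_cache
```

Layers 1..N-1 then see only a cheap read of the cached list, which is why the layer-0 check is the whole trick: `cu_seqlens` is identical across layers within one step but changes between steps.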
class AscendMMEncoderAttention(MMEncoderAttention):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden size per attention head.
            scale: scale factor.
            num_kv_heads: number of KV heads.
            prefix: this has no effect; it is only here to make it easier to
                swap between Attention and MMEncoderAttention.
        """
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            prefix=prefix,
        )
        # Only enable head-dim padding when head_size falls within the
        # supported range.
        self.enable_pad = MIN_PAD_SIZE < self.head_size < MAX_PAD_SIZE
        # Pre-compute the softmax scale to avoid recomputing it every forward.
        self.scale_value = self.head_size**-0.5
@classmethod
def maybe_compute_seq_lens(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
device: torch.device,
    ) -> torch.Tensor | None:
        """Pre-compute per-sequence lengths from ``cu_seqlens`` and keep them
        on CPU, so attention layers can reuse them without a d2h copy in
        every forward."""
        if cu_seqlens is None:
return None
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
seq_lens = torch.from_numpy(seq_lens).to("cpu", non_blocking=True)
return seq_lens
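    # Illustrative walkthrough (hypothetical values): for cu_seqlens
    # [0, 4, 10, 16], the pairwise difference above yields seq_lens
    # [4, 6, 6], i.e. one length per sequence, kept on CPU so every
    # layer can reuse it without a device-to-host copy.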
def _reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
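    # Shape walkthrough (hypothetical sizes): with bsz=2, q_len=kv_len=4,
    # num_heads=8, num_kv_heads=2 and head_size=64, the views above produce
    #   query: (8, 8, 64) and key/value: (8, 2, 64),
    # then repeat_interleave expands key/value to (8, 8, 64) so GQA/MQA can
    # go through the same dense attention kernel as MHA.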
def _maybe_compute_cu_seqlens(
self,
bsz: int,
q_len: int,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
if cu_seqlens is not None:
return cu_seqlens
        # If cu_seqlens is not provided, build a default one assuming all
        # sequences share the same length. Models such as Hunyuan-OCR always
        # pass None here and rely on this op to compute it internally.
        cu_seqlens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu"
        )
return cu_seqlens
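    # Example (hypothetical sizes): with bsz=2 and q_len=4, the default
    # cu_seqlens built above is [0, 4, 8] on CPU (int32): the cumulative
    # start offset of each equal-length sequence plus the total length.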
def forward_oot(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor | None = None,
):
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() == 4
if sequence_lengths is not None:
    # Use seq_lens pre-computed before the vision blocks.
    if sequence_lengths.device.type != "cpu":
        sequence_lengths = sequence_lengths.to("cpu")
    seq_lens_cpu = sequence_lengths
else:
    # Convert cu_seqlens to seq_lens and move it to CPU, since FA requires CPU seq_lens.
    # NOTE: this d2h copy considerably hurts performance.
    cu_seqlens = self._maybe_compute_cu_seqlens(bsz, q_len, cu_seqlens)
    seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
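# Illustration of the conversion above, with hypothetical values: for three
# sequences of lengths 4, 6 and 6, cu_seqlens = [0, 4, 10, 16], and
# torch.diff(cu_seqlens) recovers seq_lens = [4, 6, 6] on the host.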
# q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
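# Shape illustration for the reshape above, with hypothetical sizes: with
# bsz=2, q_len=8, 4 heads and head_dim=80, query [2, 8, 4, 80] is flattened
# to q [16, 4, 80] as expected by the unpadded (varlen) attention kernel.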
if self.enable_pad:
    origin_shape = q.shape[-1]
    pad_len = MAX_PAD_SIZE - origin_shape
    # [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
    q = F.pad(q, (0, pad_len), mode="constant", value=0)
    k = F.pad(k, (0, pad_len), mode="constant", value=0)
    v = F.pad(v, (0, pad_len), mode="constant", value=0)
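    # Illustration of the zero-padding above, with hypothetical sizes: a
    # head_dim of 80 padded to a MAX_PAD_SIZE of 128 appends 48 zeros on the
    # last dim, so [b * s, head, 80] -> [b * s, head, 128]; the original
    # values are untouched, only trailing zeros are added.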
# The fused attention op takes host-side cumulative sequence lengths.
seq_lens_cpu = list(seq_lens_cpu.cumsum(0))
context_layer = torch_npu.npu_fusion_attention(
    query=q,
    key=k,
    value=v,
    actual_seq_qlen=seq_lens_cpu,
    actual_seq_kvlen=seq_lens_cpu,
    head_num=self.num_heads,
    scale=self.scale_value,
    input_layout="TND",
)[0]
if self.enable_pad:
    # Trim the padded head dim back to its original size.
    context_layer = context_layer[..., :origin_shape]
if is_reshaped:
    # Restore the (batch, seq, heads, head_dim) layout.
    context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous()
else:
    # Merge heads back into the hidden dim: (batch, seq, hidden).
    context_layer = einops.rearrange(context_layer, "(b s) h d -> b s (h d)", b=bsz).contiguous()
return context_layer