From 76ac688388a3f6d16b9bb7822cb9f9648ba9b955 Mon Sep 17 00:00:00 2001
From: Shanshan Shen <467638484@qq.com>
Date: Mon, 26 Jan 2026 10:20:24 +0800
Subject: [PATCH] [MM][Perf] Parallelize Q/K/V padding in
 AscendMMEncoderAttention for better performance (#6204)

### What this PR does / why we need it?

Currently, we pad the last dimension of Q/K/V to 128 before flash attention (in `AscendMMEncoderAttention`) to get better performance on Ascend NPU. However, the three pads are executed serially, so `aclnnConstantPadNd` is launched three times and the extra launches add host-side overhead.

Since the three pads are mutually independent, we now stack Q/K/V into one tensor and pad it with a single kernel launch (see the equivalence sketch after the diff).

With this optimization, mean **TTFT** is reduced by **3.15%** and peak output token throughput is increased by **4.20%**.

---

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Launch the server:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --dtype bfloat16 \
    --limit-mm-per-prompt '{"image": 1}' \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384
```

Run the benchmark:

```bash
vllm bench serve \
    --model /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --hf-split train \
    --dataset-path lmarena-ai/vision-arena-bench-v0.1 \
    --num-prompts 1000 \
    --no-stream
```

Before this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  122.33
Total input tokens:                      66638
Total generated tokens:                  122845
Request throughput (req/s):              8.17
Output token throughput (tok/s):         1004.18
Peak output token throughput (tok/s):    3073.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1548.90
---------------Time to First Token----------------
Mean TTFT (ms):                          51757.16
Median TTFT (ms):                        44853.42
P99 TTFT (ms):                           110700.14
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          226.06
Median TPOT (ms):                        206.85
P99 TPOT (ms):                           935.31
---------------Inter-token Latency----------------
Mean ITL (ms):                           208.82
Median ITL (ms):                         96.37
P99 ITL (ms):                            2183.13
==================================================
```

After this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  121.47
Total input tokens:                      66638
Total generated tokens:                  122860
Request throughput (req/s):              8.23
Output token throughput (tok/s):         1011.47
Peak output token throughput (tok/s):    3202.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1560.08
---------------Time to First Token----------------
Mean TTFT (ms):                          50125.08
Median TTFT (ms):                        46270.85
P99 TTFT (ms):                           108107.12
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          227.11
Median TPOT (ms):                        205.13
P99 TPOT (ms):                           816.08
---------------Inter-token Latency----------------
Mean ITL (ms):                           204.60
Median ITL (ms):                         92.66
P99 ITL (ms):                            2219.02
==================================================
```

- vLLM version: v0.14.0
- vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm_ascend/ops/mm_encoder_attention.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 081bc45f..19f44066 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -106,10 +106,12 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         if enable_pad:
             origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
-            # q, k, v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
-            q = F.pad(q, (0, pad_len), mode="constant", value=0)
-            k = F.pad(k, (0, pad_len), mode="constant", value=0)
-            v = F.pad(v, (0, pad_len), mode="constant", value=0)
+            # Stack q, k, v so padding runs as a single NPU kernel launch
+            # instead of three separate aclnnConstantPadNd launches.
+            qkv = torch.stack([q, k, v], dim=0)
+            # qkv: [3, b * s, head, head_dim] -> [3, b * s, head, MAX_PAD_SIZE]
+            qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
+            q, k, v = qkv.unbind(dim=0)
 
         context_layer = torch.empty_like(q)
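
For reviewers, here is a standalone sanity check (not part of the patch) showing that padding the stacked tensor is bit-identical to the three separate pads. The tensor shapes are illustrative assumptions, `MAX_PAD_SIZE = 128` is taken from the description, and the sketch runs on CPU rather than NPU:

```python
# Equivalence sketch for the stacked-pad refactor.
# Assumed/illustrative: the [b * s, head, head_dim] shapes below.
import torch
import torch.nn.functional as F

MAX_PAD_SIZE = 128  # pad the last dim up to this size, as in the patch

# Dummy q/k/v with shape [b * s, head, head_dim].
q, k, v = (torch.randn(64, 8, 72) for _ in range(3))
pad_len = MAX_PAD_SIZE - q.shape[-1]

# Old path: three separate pad launches.
q_ref = F.pad(q, (0, pad_len), mode="constant", value=0)
k_ref = F.pad(k, (0, pad_len), mode="constant", value=0)
v_ref = F.pad(v, (0, pad_len), mode="constant", value=0)

# New path: one stacked pad, then unbind back into per-tensor views.
qkv = torch.stack([q, k, v], dim=0)      # [3, b * s, head, head_dim]
qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
q_new, k_new, v_new = qkv.unbind(dim=0)  # each [b * s, head, MAX_PAD_SIZE]

for ref, new in ((q_ref, q_new), (k_ref, k_new), (v_ref, v_new)):
    assert torch.equal(ref, new)
print("stacked pad matches per-tensor pads")
```

The trade-off is that `torch.stack` materializes one extra copy, but trading two of the three `aclnnConstantPadNd` launches for that single stack appears to pay off on NPU, as the TTFT and peak-throughput numbers above indicate.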