[MM][Perf] Use seq_lens CPU cache to avoid frequent d2h copy for better performance (#6448)

### What this PR does / why we need it?

Currently, the performance of multi-modal encoding (i.e., the
`AscendMMEncoderAttention` forward pass) is considerably bounded by heavy
host-side pre-processing operations.

As the profiling results below show, the device sits idle for long stretches
before the actual attention computation starts, which leads to extremely low
NPU utilization.

<img width="2264" height="1398" alt="iShot_2026-01-23_16 26 39"
src="https://github.com/user-attachments/assets/37f21d06-e526-4f28-82fe-005746cf13bd"
/>

---
**To optimize this, this PR proposes four changes:**

1. Use a `seq_lens` CPU cache to avoid frequent d2h copies. Before this PR,
`AscendMMEncoderAttention` copied `cu_seqlens` from NPU to CPU in every
forward pass, since the op `_npu_flash_attention_unpad()` requires
`cu_seqlens` on the CPU (otherwise it crashes). We therefore cache this
tensor in `seq_lens_cpu_cache`: it is shared between all layers but may
change across forward steps. When the current `layer_index` is `0`, we
update the cache; otherwise we use the cached tensor directly, avoiding the
frequent `diff` and copy operations, which are costly.
2. Pre-compute the scale value to avoid recalculating it in every forward pass.
3. Move the `enable_pad` check from the forward pass to the `__init__`
method.
4. Revert https://github.com/vllm-project/vllm-ascend/pull/6204.
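Change 1 can be sketched in isolation as follows. This is a minimal, hypothetical standalone version of the caching pattern: `get_seq_lens_cpu` is an illustrative name rather than the PR's actual code, and plain CPU tensors stand in for NPU ones.

```python
import torch

# Shared across all layers within one forward step; refreshed at layer 0.
seq_lens_cpu_cache = None


def get_seq_lens_cpu(layer_index: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Return per-sequence lengths on CPU, computing them only at layer 0."""
    global seq_lens_cpu_cache
    if layer_index == 0:
        # One diff + one d2h copy per forward step, instead of one per layer.
        seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
    return seq_lens_cpu_cache


# cu_seqlens for 3 sequences of lengths 4, 2, 6.
cu = torch.tensor([0, 4, 6, 12], dtype=torch.int32)
lens0 = get_seq_lens_cpu(0, cu)  # layer 0: computes and caches
lens1 = get_seq_lens_cpu(1, cu)  # layer 1: cache hit, no diff/copy
assert lens1 is lens0            # same tensor object reused
assert lens0.tolist() == [4, 2, 6]
```

Later layers receive the exact same CPU tensor, so the per-layer `diff` and d2h copy disappear from the hot path.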

**Performance after these optimizations:**

- **TTFT** has been reduced by **7.43%** ⬇️.
- **Throughput** has been increased by **1.23%** ⬆️.

---
> [!NOTE]
> This PR requires https://github.com/vllm-project/vllm/pull/33674 to be
> merged first.

---
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Launch the server:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--dtype bfloat16 \
--limit-mm-per-prompt '{"image": 1}' \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--no-async-scheduling
```

Run benchmark:

```bash
vllm bench serve \
--model /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--hf-split train \
--dataset-path lmarena-ai/vision-arena-bench-v0.1 \
--num-prompts 500 \
--request-rate 10 \
--burstiness 5 \
--no-stream
```

Before this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500       
Failed requests:                         0         
Request rate configured (RPS):           10.00     
Benchmark duration (s):                  82.23     
Total input tokens:                      33418     
Total generated tokens:                  61543     
Request throughput (req/s):              6.08      
Output token throughput (tok/s):         748.45    
Peak output token throughput (tok/s):    3203.00   
Peak concurrent requests:                402.00    
Total token throughput (tok/s):          1154.86   
---------------Time to First Token----------------
Mean TTFT (ms):                          10275.37  
Median TTFT (ms):                        6297.88   
P99 TTFT (ms):                           22918.26  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          263.02    
Median TPOT (ms):                        277.61    
P99 TPOT (ms):                           483.56    
---------------Inter-token Latency----------------
Mean ITL (ms):                           257.31    
Median ITL (ms):                         94.83     
P99 ITL (ms):                            1773.90   
==================================================
```

After this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     500       
Failed requests:                         0         
Request rate configured (RPS):           10.00     
Benchmark duration (s):                  81.20     
Total input tokens:                      33418     
Total generated tokens:                  61509     
Request throughput (req/s):              6.16      
Output token throughput (tok/s):         757.54    
Peak output token throughput (tok/s):    2562.00   
Peak concurrent requests:                395.00    
Total token throughput (tok/s):          1169.11   
---------------Time to First Token----------------
Mean TTFT (ms):                          9511.91   
Median TTFT (ms):                        5479.78   
P99 TTFT (ms):                           21427.21  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          261.12    
Median TPOT (ms):                        276.03    
P99 TPOT (ms):                           446.99    
---------------Inter-token Latency----------------
Mean ITL (ms):                           254.04    
Median ITL (ms):                         97.71     
P99 ITL (ms):                            1516.67   
==================================================
```
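As a quick sanity check, the headline percentages can be reproduced from the mean TTFT and total token throughput rows of the two tables above:

```python
# Values copied from the "Before" and "After" benchmark tables.
before_ttft, after_ttft = 10275.37, 9511.91    # Mean TTFT (ms)
before_tput, after_tput = 1154.86, 1169.11     # Total token throughput (tok/s)

ttft_drop = (before_ttft - after_ttft) / before_ttft * 100
tput_gain = (after_tput - before_tput) / before_tput * 100

print(f"TTFT reduced by {ttft_drop:.2f}%")         # 7.43%
print(f"Throughput increased by {tput_gain:.2f}%")  # 1.23%
```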

- vLLM version: v0.15.0
- vLLM main:
dc917cceb8

Signed-off-by: shen-shanshan <467638484@qq.com>
Commit: 3a4292e5b7 (parent: 29e3cdde20)
Author: Shanshan Shen
Date: 2026-02-26 08:49:36 +08:00
Committed by: GitHub


```diff
@@ -21,8 +21,20 @@ import torch.nn.functional as F
 import torch_npu
 from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention  # type: ignore
+from vllm_ascend.utils import vllm_version_is
 
-MIN_PAD_SIZE = 64  # min_size to pad weight
-MAX_PAD_SIZE = 128  # max_size to pad weight
+MIN_PAD_SIZE: int = 64  # min_size to pad weight
+MAX_PAD_SIZE: int = 128  # max_size to pad weight
+
+# Use seq_lens CPU cache to avoid frequent d2h copy.
+# AscendMMEncoderAttention will copy the cu_seqlens from NPU to CPU in every
+# forward, since the op _npu_flash_attention_unpad() requires CPU cu_seqlens
+# (otherwise it will break down).
+# Thus, we use seq_lens_cpu_cache to cache this tensor, since it's shared
+# between all layers, but may change in different forward step. When the
+# current layer_index is 0, we update the cache, otherwise we directly use the
+# cache to avoid frequent diff and copy operations, which are costful.
+seq_lens_cpu_cache: torch.Tensor = None
 
 class AscendMMEncoderAttention(MMEncoderAttention):
@@ -52,7 +64,13 @@ class AscendMMEncoderAttention(MMEncoderAttention):
             prefix=prefix,
         )
 
-    def reshape_qkv_to_3d(
+        if not vllm_version_is("0.15.0"):
+            self.layer_index = int("".join(filter(str.isdigit, prefix)))
+        self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
+        self.scale_value = self.head_size**-0.5
+
+    def _reshape_qkv_to_3d(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
@@ -88,41 +106,46 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         kv_len = key.size(1)
         is_reshaped = query.dim() == 4
 
+        if vllm_version_is("0.15.0"):
+            if cu_seqlens is None:
+                cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+            seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
+        else:
+            global seq_lens_cpu_cache
+            if self.layer_index == 0:
+                if cu_seqlens is None:
+                    cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+                # Update seq_lens cpu cache.
+                seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
+            # Directly use seq_lens cpu cache to avoid d2h copy.
+            seq_lens_cpu = seq_lens_cpu_cache
+
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
-        q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
+        q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
 
-        enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
-        if enable_pad:
+        if self.enable_pad:
             origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
-            # [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
-            q = F.pad(q, (0, pad_len), mode="constant", value=0)
-            k = F.pad(k, (0, pad_len), mode="constant", value=0)
-            v = F.pad(v, (0, pad_len), mode="constant", value=0)
+            # Merge qkv to reduce the overhead of launching npu pad operation.
+            # [3, b * s, head, head_dim]
+            qkv = torch.stack([q, k, v], dim=0)
+            # qkv: [3, b * s, head, head_dim] -> [3, b * s, head, MAX_PAD_SIZE]
+            qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
+            q, k, v = qkv.unbind(dim=0)
 
         context_layer = torch.empty_like(q)
 
-        if cu_seqlens is None:
-            cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device)
-        cu_seqlens = torch.diff(cu_seqlens).to("cpu")
-
         # operator requires pta version >= 2.5.1
         torch_npu._npu_flash_attention_unpad(
             query=q,
             key=k,
             value=v,
-            seq_len=cu_seqlens,
-            scale_value=self.head_size**-0.5,
+            seq_len=seq_lens_cpu,
+            scale_value=self.scale_value,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             out=context_layer,
         )
 
-        if enable_pad:
+        if self.enable_pad:
             context_layer = context_layer[..., :origin_shape]
 
         if is_reshaped:
```
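The merge-qkv-before-padding trick above can be reproduced with plain `torch` on CPU (toy shapes; the real code pads `head_dim` up to `MAX_PAD_SIZE` on NPU): stacking q, k, v first means one `F.pad` launch instead of three, with identical results.

```python
import torch
import torch.nn.functional as F

MAX_PAD_SIZE = 128
q = torch.randn(6, 8, 80)  # [b * s, head, head_dim] with head_dim = 80
k = torch.randn(6, 8, 80)
v = torch.randn(6, 8, 80)
pad_len = MAX_PAD_SIZE - q.shape[-1]

# Single pad launch: stack to [3, b * s, head, head_dim], pad last dim once.
qkv = F.pad(torch.stack([q, k, v], dim=0), (0, pad_len), mode="constant", value=0)
q2, k2, v2 = qkv.unbind(dim=0)

# Equivalent to padding each tensor separately, minus two kernel launches.
assert q2.shape == (6, 8, MAX_PAD_SIZE)
assert torch.equal(q2, F.pad(q, (0, pad_len)))
assert torch.equal(k2, F.pad(k, (0, pad_len)))
```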