[Perf] Improve MLA multistream performance (#1353)

### What this PR does / why we need it?
> Need to merge after PR #1322

According to benchmark results, this PR brings approximately 1%
performance gain.

#### Before Improvement
Profiling
<img width="1147" alt="截屏2025-06-22 14 54 47"
src="https://github.com/user-attachments/assets/4a4dc7f1-5b76-45d5-864d-dd7f8faf993c"
/>

Evaluation
```
# server launch command
python -m vllm.entrypoints.openai.api_server --model=/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=16 \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 8192 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config":{"enable_multistream_mla": true,"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.96

# client benchmark command
python /root/vllm/benchmarks/benchmark_serving.py --backend vllm --dataset-name random \
        --random-input-len 4096 \
        --random-output-len 1536 \
        --num-prompts 200 \
        --ignore-eos \
        --model auto \
        --tokenizer /DeepSeek-R1-W8A8 \
        --port 8006 \
        --request-rate 1 \
        --max-concurrency 24 \
        --save-result \
        --skip-initial-test \
        --metric-percentiles "50,90,99"
```

```
============ Serving Benchmark Result ============
Successful requests:                     200       
Benchmark duration (s):                  958.59    
Total input tokens:                      819200    
Total generated tokens:                  307200    
Request throughput (req/s):              0.2086    
Output token throughput (tok/s):         320.47    
Total Token throughput (tok/s):          1175.05   
---------------Time to First Token----------------
Mean TTFT (ms):                          942.70    
Median TTFT (ms):                        713.87    
P50 TTFT (ms):                           713.87    
P90 TTFT (ms):                           1363.88   
P99 TTFT (ms):                           2008.73   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          68.96     
Median TPOT (ms):                        69.49     
P50 TPOT (ms):                           69.49     
P90 TPOT (ms):                           70.42     
P99 TPOT (ms):                           70.72     
---------------Inter-token Latency----------------
Mean ITL (ms):                           68.96     
Median ITL (ms):                         59.88     
P50 ITL (ms):                            59.88     
P90 ITL (ms):                            61.59     
P99 ITL (ms):                            68.82     
==================================================
```

#### After Improvement
Profiling
<img width="1200" alt="截屏2025-06-22 14 55 42"
src="https://github.com/user-attachments/assets/e3eb9dec-0ff0-4e5f-ab94-93c65003e51f"
/>

Evaluation
```
============ Serving Benchmark Result ============
Successful requests:                     200       
Benchmark duration (s):                  948.08    
Total input tokens:                      819200    
Total generated tokens:                  307200    
Request throughput (req/s):              0.2110    
Output token throughput (tok/s):         324.02    
Total Token throughput (tok/s):          1188.08   
---------------Time to First Token----------------
Mean TTFT (ms):                          1019.25   
Median TTFT (ms):                        714.63    
P50 TTFT (ms):                           714.63    
P90 TTFT (ms):                           1367.31   
P99 TTFT (ms):                           2661.52   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          68.14     
Median TPOT (ms):                        68.68     
P50 TPOT (ms):                           68.68     
P90 TPOT (ms):                           69.33     
P99 TPOT (ms):                           70.30     
---------------Inter-token Latency----------------
Mean ITL (ms):                           68.14     
Median ITL (ms):                         59.04     
P50 ITL (ms):                            59.04     
P90 ITL (ms):                            60.93     
P99 ITL (ms):                            66.89     
==================================================
```
### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?




- vLLM version: v0.9.2
- vLLM main:
65393ee064

Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
ApsarasX
2025-07-11 08:51:17 +08:00
committed by GitHub
parent cc210f46e6
commit 0fc9b56d40
3 changed files with 58 additions and 30 deletions

View File

@@ -21,7 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@@ -579,13 +579,18 @@ class AscendMLAImpl(MLAAttentionImpl):
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
"{32, 64, 128}.")
def _v_up_proj_and_o_proj(self, x):
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
return self.o_proj(x, is_prefill=False)[0]
# Return `ql_nope`, `q_pe`
@@ -864,7 +869,6 @@ class AscendMLAImpl(MLAAttentionImpl):
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
enable_multistream_mla: bool = False,
):
B = hidden_states.shape[0]
@@ -874,21 +878,18 @@ class AscendMLAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope, kv
def exec_kv_prefill(
self,
@@ -940,6 +941,7 @@ class AscendMLAImpl(MLAAttentionImpl):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
@@ -1020,7 +1022,8 @@ class AscendMLAImpl(MLAAttentionImpl):
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output)
return self._v_up_proj_and_o_proj(attn_output,
enable_multistream_mla)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_metadata: M,
output: Optional[torch.Tensor] = None,
enable_multistream_mla: bool = False,
ckq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
@@ -1091,6 +1095,15 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
npu_wait_tensor(hidden_states_or_kv_c_normed,
ckq,
enabled=enable_multistream_mla)
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@ class AscendMLAImpl(MLAAttentionImpl):
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
decode_kv,
enabled=enable_multistream_mla)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping, enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
@@ -1194,7 +1208,8 @@ class AscendMLAImpl(MLAAttentionImpl):
if self.running_in_graph:
return self._forward_decode(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe,
kv_cache, attn_metadata)
kv_cache, attn_metadata,
enable_multistream_mla)
else:
output_decode = self._forward_decode(decode_ql_nope,
decode_q_pe,

View File

@@ -74,8 +74,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)
from vllm_ascend.utils import dispose_tensor, npu_prefetch
class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -567,12 +566,12 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]
npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
hidden_states_or_q_c = self.q_a_layernorm(ckq)
forward_kwargs['ckq'] = ckq
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:

View File

@@ -416,6 +416,20 @@ def npu_wait_tensor(self: torch.Tensor,
return _npu_wait_tensor(self, dependency) if enabled else self
# TODO(wxy): Move to ops module
def npu_prefetch(input: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
*,
enabled: bool = True):
if not enabled:
return
input_size = input.element_size() * input.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(input, dependency, max_size)
# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
AllGather = 0