[Perf] Improve MLA multistream performance (#1353)
### What this PR does / why we need it?
> Note: this PR needs to be merged after PR #1322.
According to benchmark results, this PR brings approximately 1%
performance gain.
#### Before Improvement
Profiling
<img width="1147" alt="Screenshot 2025-06-22 14:54:47"
src="https://github.com/user-attachments/assets/4a4dc7f1-5b76-45d5-864d-dd7f8faf993c"
/>
Evaluation
```
# server launch command
python -m vllm.entrypoints.openai.api_server --model=/DeepSeek-R1-W8A8 \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=16 \
--max-num-seqs 24 \
--max-model-len 32768 \
--max-num-batched-tokens 8192 \
--block-size 128 \
--no-enable-prefix-caching \
--additional-config '{"torchair_graph_config":{"enable_multistream_mla": true,"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true},"expert_tensor_parallel_size":16}' \
--gpu-memory-utilization 0.96
# client benchmark command
python /root/vllm/benchmarks/benchmark_serving.py --backend vllm --dataset-name random \
--random-input-len 4096 \
--random-output-len 1536 \
--num-prompts 200 \
--ignore-eos \
--model auto \
--tokenizer /DeepSeek-R1-W8A8 \
--port 8006 \
--request-rate 1 \
--max-concurrency 24 \
--save-result \
--skip-initial-test \
--metric-percentiles "50,90,99"
```
```
============ Serving Benchmark Result ============
Successful requests: 200
Benchmark duration (s): 958.59
Total input tokens: 819200
Total generated tokens: 307200
Request throughput (req/s): 0.2086
Output token throughput (tok/s): 320.47
Total Token throughput (tok/s): 1175.05
---------------Time to First Token----------------
Mean TTFT (ms): 942.70
Median TTFT (ms): 713.87
P50 TTFT (ms): 713.87
P90 TTFT (ms): 1363.88
P99 TTFT (ms): 2008.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 68.96
Median TPOT (ms): 69.49
P50 TPOT (ms): 69.49
P90 TPOT (ms): 70.42
P99 TPOT (ms): 70.72
---------------Inter-token Latency----------------
Mean ITL (ms): 68.96
Median ITL (ms): 59.88
P50 ITL (ms): 59.88
P90 ITL (ms): 61.59
P99 ITL (ms): 68.82
==================================================
```
#### After Improvement
Profiling
<img width="1200" alt="Screenshot 2025-06-22 14:55:42"
src="https://github.com/user-attachments/assets/e3eb9dec-0ff0-4e5f-ab94-93c65003e51f"
/>
Evaluation
```
============ Serving Benchmark Result ============
Successful requests: 200
Benchmark duration (s): 948.08
Total input tokens: 819200
Total generated tokens: 307200
Request throughput (req/s): 0.2110
Output token throughput (tok/s): 324.02
Total Token throughput (tok/s): 1188.08
---------------Time to First Token----------------
Mean TTFT (ms): 1019.25
Median TTFT (ms): 714.63
P50 TTFT (ms): 714.63
P90 TTFT (ms): 1367.31
P99 TTFT (ms): 2661.52
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 68.14
Median TPOT (ms): 68.68
P50 TPOT (ms): 68.68
P90 TPOT (ms): 69.33
P99 TPOT (ms): 70.30
---------------Inter-token Latency----------------
Mean ITL (ms): 68.14
Median ITL (ms): 59.04
P50 ITL (ms): 59.04
P90 ITL (ms): 60.93
P99 ITL (ms): 66.89
==================================================
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
65393ee064
Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -579,13 +579,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
|
||||
"{32, 64, 128}.")
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
||||
npu_prefetch(self.o_proj.weight,
|
||||
x,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
return self.o_proj(x, is_prefill=False)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
@@ -864,7 +869,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
enable_multistream_mla: bool = False,
|
||||
):
|
||||
|
||||
B = hidden_states.shape[0]
|
||||
@@ -874,21 +878,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
with npu_stream_switch("mla_secondary",
|
||||
0,
|
||||
enabled=enable_multistream_mla):
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
return k_pe, k_nope
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
return k_pe, k_nope, kv
|
||||
|
||||
def exec_kv_prefill(
|
||||
self,
|
||||
@@ -940,6 +941,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
enable_multistream_mla: bool = False,
|
||||
) -> torch.Tensor:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
@@ -1020,7 +1022,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
out=attn_output)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj_and_o_proj(attn_output)
|
||||
return self._v_up_proj_and_o_proj(attn_output,
|
||||
enable_multistream_mla)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
@@ -1037,6 +1040,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
enable_multistream_mla: bool = False,
|
||||
ckq: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
@@ -1091,6 +1095,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sin = sin[attn_metadata.decode.input_positions]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
with npu_stream_switch("mla_secondary",
|
||||
0,
|
||||
enabled=enable_multistream_mla):
|
||||
npu_wait_tensor(hidden_states_or_kv_c_normed,
|
||||
ckq,
|
||||
enabled=enable_multistream_mla)
|
||||
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
|
||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping)
|
||||
# Without explicitly controlling the order, IndexByTensor operations
|
||||
# would be placed after `matmul W_KV_T` hindering the overlapping of
|
||||
# KvRmsNormRopeCache and SingleRope.
|
||||
@@ -1100,12 +1113,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
npu_wait_tensor(decode_hs_or_q_c,
|
||||
sin,
|
||||
enabled=enable_multistream_mla)
|
||||
npu_wait_tensor(decode_hs_or_q_c,
|
||||
decode_kv,
|
||||
enabled=enable_multistream_mla)
|
||||
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
if self.running_in_graph:
|
||||
decode_k_pe, decode_k_nope = self.exec_kv(
|
||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping, enable_multistream_mla)
|
||||
with npu_stream_switch("mla_secondary",
|
||||
0,
|
||||
enabled=enable_multistream_mla):
|
||||
@@ -1194,7 +1208,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if self.running_in_graph:
|
||||
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
||||
decode_k_nope, decode_k_pe,
|
||||
kv_cache, attn_metadata)
|
||||
kv_cache, attn_metadata,
|
||||
enable_multistream_mla)
|
||||
else:
|
||||
output_decode = self._forward_decode(decode_ql_nope,
|
||||
decode_q_pe,
|
||||
|
||||
@@ -74,8 +74,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
from vllm_ascend.utils import dispose_tensor, npu_prefetch
|
||||
|
||||
|
||||
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
||||
@@ -567,12 +566,12 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
and attn_metadata.num_decodes > 0)
|
||||
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
||||
if self.q_lora_rank is not None:
|
||||
npu_prefetch(self.q_a_proj.weight,
|
||||
hidden_states,
|
||||
enabled=enable_multistream_mla)
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
|
||||
with npu_stream_switch("mla_secondary",
|
||||
0,
|
||||
enabled=enable_multistream_mla):
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
forward_kwargs['ckq'] = ckq
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
if self.torchair_graph_enabled:
|
||||
|
||||
@@ -416,6 +416,20 @@ def npu_wait_tensor(self: torch.Tensor,
|
||||
return _npu_wait_tensor(self, dependency) if enabled else self
|
||||
|
||||
|
||||
# TODO(wxy): Move to ops module
|
||||
def npu_prefetch(input: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
max_size: int = 0,
|
||||
*,
|
||||
enabled: bool = True):
|
||||
if not enabled:
|
||||
return
|
||||
input_size = input.element_size() * input.numel()
|
||||
if max_size <= 0 or max_size > input_size:
|
||||
max_size = input_size
|
||||
torch_npu.npu_prefetch(input, dependency, max_size)
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): move this into forward_context
|
||||
class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
|
||||
Reference in New Issue
Block a user