From c3c91387191602e7ccc521f4f9957369b033ea52 Mon Sep 17 00:00:00 2001
From: XiaoxinWang <963372609@qq.com>
Date: Mon, 10 Nov 2025 17:18:45 +0800
Subject: [PATCH] [Perf] Move attention update stream out of loop to optimize
performance (#3985)
### What this PR does / why we need it?
In the `update_*attn_params` functions, the
`torch.npu.stream(update_stream)` context manager was previously located
inside the for-loop that updates parameters for each layer. This
resulted in redundant stream initiations for every layer, adding
unnecessary overhead.
This commit refactors the code by moving the stream context manager to
wrap the entire for-loop. This ensures that the update stream is
initiated only once per function call, rather than for each layer. This
change reduces 90us in each decode model.
update stream in every layer:
remove update stream in every layer:
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac
Signed-off-by: wangxiaoxin-sherie
Co-authored-by: wangxiaoxin-sherie
---
vllm_ascend/compilation/acl_graph.py | 89 ++++++++++++++--------------
1 file changed, 46 insertions(+), 43 deletions(-)
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index d3e779e..c96b348 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -193,24 +193,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
- for key, param, handle, event in zip(
- forward_context.attn_metadata,
- graph_params.attn_params[runtime_shape],
- graph_params.handles[runtime_shape],
- graph_params.events[runtime_shape],
- ):
- (
- query,
- key_cache,
- value_cache,
- num_kv_heads,
- num_heads,
- scale,
- block_table,
- seq_lens,
- output,
- ) = param
- seq_lens = forward_context.attn_metadata[key].seq_lens
+ with torch.npu.stream(update_stream):
+ for key, param, handle, event in zip(
+ forward_context.attn_metadata,
+ graph_params.attn_params[runtime_shape],
+ graph_params.handles[runtime_shape],
+ graph_params.events[runtime_shape],
+ ):
+ (
+ query,
+ key_cache,
+ value_cache,
+ num_kv_heads,
+ num_heads,
+ scale,
+ block_table,
+ seq_lens,
+ output,
+ ) = param
+ seq_lens = forward_context.attn_metadata[key].seq_lens
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
@@ -253,31 +254,33 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
- for key, param, handle, event in zip(
- forward_context.attn_metadata,
- graph_params.attn_params[runtime_shape],
- graph_params.handles[runtime_shape],
- graph_params.events[runtime_shape],
- ):
- (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
- spec_attn_mask, sparse_mode, scale, block_table, block_size,
- seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
- seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
- if speculative_config and speculative_config.method == "deepseek_mtp":
- actual_seq_lengths = forward_context.attn_metadata[
- key].decode.actual_seq_lengths_q
- spec_multiple = speculative_config.num_speculative_tokens + 1
- seq_lens_list = seq_lens_list + [0] * (
- runtime_shape // spec_multiple - len(seq_lens_list))
- actual_seq_lengths = [
- spec_multiple * (i + 1)
- for i in range(runtime_shape // spec_multiple)
- ]
- else:
- seq_lens_list = seq_lens_list + [0] * (runtime_shape -
- len(seq_lens_list))
- with torch.npu.stream(update_stream):
- torch.npu.graph_task_update_begin(update_stream, handle)
+ with torch.npu.stream(update_stream):
+ for key, param, handle, event in zip(
+ forward_context.attn_metadata,
+ graph_params.attn_params[runtime_shape],
+ graph_params.handles[runtime_shape],
+ graph_params.events[runtime_shape],
+ ):
+ (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+ spec_attn_mask, sparse_mode, scale, block_table, block_size,
+ seq_lens_list, actual_seq_lengths, attn_output,
+ softmax_lse) = param
+ seq_lens_list = forward_context.attn_metadata[
+ key].decode.seq_lens_list
+ if speculative_config and speculative_config.method == "deepseek_mtp":
+ actual_seq_lengths = forward_context.attn_metadata[
+ key].decode.actual_seq_lengths_q
+ spec_multiple = speculative_config.num_speculative_tokens + 1
+ seq_lens_list = seq_lens_list + [0] * (
+ runtime_shape // spec_multiple - len(seq_lens_list))
+ actual_seq_lengths = [
+ spec_multiple * (i + 1)
+ for i in range(runtime_shape // spec_multiple)
+ ]
+ else:
+ seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+ len(seq_lens_list))
+ torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,