[Perf] Move attention update stream out of loop to optimize performance (#3985)

### What this PR does / why we need it?
In the `update_*attn_params` functions, the
`torch.npu.stream(update_stream)` context manager was previously located
inside the for-loop that updates parameters for each layer. This
resulted in redundant stream initiations for every layer, adding
unnecessary overhead.

This commit refactors the code by moving the stream context manager to
wrap the entire for-loop. This ensures that the update stream is
initiated only once per function call, rather than for each layer. This
change reduces 90us in each decode model.
update stream in every layer:
<img width="1720" height="383" alt="image"
src="https://github.com/user-attachments/assets/70e4cb69-5bc1-4180-a67d-c99132134be6"
/>

remove update stream in every layer:
<img width="1269" height="175" alt="image"
src="https://github.com/user-attachments/assets/0e290edb-b0ce-48fe-b032-1b924ade6ae5"
/>

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-11-10 17:18:45 +08:00
committed by GitHub
parent d913f9474b
commit c3c9138719

View File

@@ -193,24 +193,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
@@ -253,31 +254,33 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,