[Perf] Move attention update stream out of loop to optimize performance (#3848)

### What this PR does / why we need it?
In the `update_*attn_params` functions, the
`torch.npu.stream(update_stream)` context manager was previously located
inside the for-loop that updates parameters for each layer. This
resulted in redundant stream initiations for every layer, adding
unnecessary overhead.

This commit refactors the code by moving the stream context manager to
wrap the entire for-loop. This ensures that the update stream is
initiated only once per function call, rather than once for each layer. This
change saves roughly 90 µs per call in each decode model.
Before (the update stream was initiated in every layer):
<img width="1720" height="383" alt="image"
src="https://github.com/user-attachments/assets/70e4cb69-5bc1-4180-a67d-c99132134be6"
/>

After (the per-layer stream initiation removed):
<img width="1269" height="175" alt="image"
src="https://github.com/user-attachments/assets/0e290edb-b0ce-48fe-b032-1b924ade6ae5"
/>

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-11-03 09:19:57 +08:00
committed by GitHub
parent d0cc9c1203
commit d4c75088a0

View File

@@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
for key, param, handle, event in zip( with torch.npu.stream(update_stream):
forward_context.attn_metadata, for key, param, handle, event in zip(
graph_params.attn_params[runtime_shape], forward_context.attn_metadata,
graph_params.handles[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.events[runtime_shape], graph_params.handles[runtime_shape],
): graph_params.events[runtime_shape],
( ):
query, (
key_cache, query,
value_cache, key_cache,
num_kv_heads, value_cache,
num_heads, num_kv_heads,
scale, num_heads,
block_table, scale,
seq_lens, block_table,
output, seq_lens,
) = param output,
seq_lens = forward_context.attn_metadata[key].seq_lens ) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention( torch_npu._npu_paged_attention(
query=query, query=query,
@@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
for key, param, handle, event in zip( with torch.npu.stream(update_stream):
forward_context.attn_metadata, for key, param, handle, event in zip(
graph_params.attn_params[runtime_shape], forward_context.attn_metadata,
graph_params.handles[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.events[runtime_shape], graph_params.handles[runtime_shape],
): graph_params.events[runtime_shape],
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, ):
spec_attn_mask, sparse_mode, scale, block_table, block_size, (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list seq_lens_list, actual_seq_lengths, attn_output,
if speculative_config and speculative_config.method == "deepseek_mtp": softmax_lse) = param
actual_seq_lengths = forward_context.attn_metadata[ seq_lens_list = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q key].decode.seq_lens_list
spec_multiple = speculative_config.num_speculative_tokens + 1 if speculative_config and speculative_config.method == "deepseek_mtp":
seq_lens_list = seq_lens_list + [0] * ( actual_seq_lengths = forward_context.attn_metadata[
runtime_shape // spec_multiple - len(seq_lens_list)) key].decode.actual_seq_lengths_q
actual_seq_lengths = [ spec_multiple = speculative_config.num_speculative_tokens + 1
spec_multiple * (i + 1) seq_lens_list = seq_lens_list + [0] * (
for i in range(runtime_shape // spec_multiple) runtime_shape // spec_multiple - len(seq_lens_list))
] actual_seq_lengths = [
else: spec_multiple * (i + 1)
seq_lens_list = seq_lens_list + [0] * (runtime_shape - for i in range(runtime_shape // spec_multiple)
len(seq_lens_list)) ]
with torch.npu.stream(update_stream): else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
@@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
for key, param, handle, event in zip( with torch.npu.stream(update_stream):
forward_context.attn_metadata, for key, param, handle, event in zip(
graph_params.attn_params[runtime_shape], forward_context.attn_metadata,
graph_params.handles[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.events[runtime_shape], graph_params.handles[runtime_shape],
): graph_params.events[runtime_shape],
(q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table, ):
block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank, (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
dcp_rank, dcp_size) = param block_table, block_size, actual_seq_lengths_kv, attn_output,
actual_seq_lengths_kv = forward_context.attn_metadata[ softmax_lse, cp_rank, dcp_rank, dcp_size) = param
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank, actual_seq_lengths_kv = forward_context.attn_metadata[
dcp_rank] key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
pad_length = runtime_shape - len(actual_seq_lengths_kv) dcp_rank]
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) pad_length = runtime_shape - len(actual_seq_lengths_kv)
actual_seq_lengths_kv = np.concatenate( pad_tensor = np.zeros(pad_length,
[actual_seq_lengths_kv, pad_tensor]) dtype=actual_seq_lengths_kv.dtype)
if dcp_size > 1: actual_seq_lengths_kv = np.concatenate(
num_heads = num_heads * dcp_size [actual_seq_lengths_kv, pad_tensor])
if dcp_size > 1:
num_heads = num_heads * dcp_size
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
@@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
for key, param, handle, event in zip( with torch.npu.stream(update_stream):
forward_context.attn_metadata, for key, param, handle, event in zip(
graph_params.attn_params[runtime_shape], forward_context.attn_metadata,
graph_params.handles[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.events[runtime_shape], graph_params.handles[runtime_shape],
): graph_params.events[runtime_shape],
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale, ):
num_kv_heads, attn_output, softmax_lse) = param (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
scale, num_kv_heads, attn_output, softmax_lse) = param
decode_meta = forward_context.attn_metadata[key].decode decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len seq_len = decode_meta.cp_seq_len
if speculative_config and speculative_config.method == "deepseek_mtp": if speculative_config and speculative_config.method == "deepseek_mtp":
spec_multiple = speculative_config.num_speculative_tokens + 1 spec_multiple = speculative_config.num_speculative_tokens + 1
seq_len = seq_len + [0] * (runtime_shape // spec_multiple - seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
len(seq_len)) len(seq_len))
else: else:
pad_length = runtime_shape - len(seq_len) pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length, pad_tensor = torch.zeros(pad_length,
dtype=seq_len.dtype, dtype=seq_len.dtype,
device=seq_len.device) device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0) seq_len = torch.cat([seq_len, pad_tensor], dim=0)
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.atb.npu_multi_head_latent_attention( torch_npu.atb.npu_multi_head_latent_attention(