From d4c75088a0bee0b3519662465b80fcd3b78bce86 Mon Sep 17 00:00:00 2001 From: XiaoxinWang <963372609@qq.com> Date: Mon, 3 Nov 2025 09:19:57 +0800 Subject: [PATCH] [Perf] Move attention update stream out of loop to optimize performance (#3848) ### What this PR does / why we need it? In the `update_*attn_params` functions, the `torch.npu.stream(update_stream)` context manager was previously located inside the for-loop that updates parameters for each layer. This resulted in redundant stream initiations for every layer, adding unnecessary overhead. This commit refactors the code by moving the stream context manager to wrap the entire for-loop. This ensures that the update stream is initiated only once per function call, rather than for each layer. This change reduces 90us in each decode model. update stream in every layer: image remove update stream in every layer: image ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- vllm_ascend/compilation/acl_graph.py | 170 ++++++++++++++------------- 1 file changed, 86 insertions(+), 84 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 17ca56b7..41476ccc 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], - ): - ( - query, - key_cache, - value_cache, - num_kv_heads, - num_heads, - scale, - block_table, - seq_lens, - output, - ) = param - seq_lens = forward_context.attn_metadata[key].seq_lens - - with torch.npu.stream(update_stream): + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + seq_lens = forward_context.attn_metadata[key].seq_lens torch.npu.graph_task_update_begin(update_stream, handle) torch_npu._npu_paged_attention( query=query, @@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], - ): - (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, - spec_attn_mask, sparse_mode, scale, block_table, block_size, - seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param - seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list - if speculative_config and speculative_config.method == "deepseek_mtp": - actual_seq_lengths = forward_context.attn_metadata[ - key].decode.actual_seq_lengths_q - spec_multiple = speculative_config.num_speculative_tokens + 1 - seq_lens_list = seq_lens_list + [0] * ( - runtime_shape // spec_multiple - len(seq_lens_list)) - actual_seq_lengths = [ - spec_multiple * (i + 1) - for i in range(runtime_shape // spec_multiple) - ] - else: - seq_lens_list = seq_lens_list + [0] * (runtime_shape - - len(seq_lens_list)) - with torch.npu.stream(update_stream): + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + spec_attn_mask, sparse_mode, scale, block_table, block_size, + seq_lens_list, actual_seq_lengths, attn_output, + softmax_lse) = param + seq_lens_list = forward_context.attn_metadata[ + key].decode.seq_lens_list + if speculative_config and speculative_config.method == "deepseek_mtp": + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_lens_list = seq_lens_list + [0] * ( + runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [ + spec_multiple * (i + 1) + for i in range(runtime_shape // spec_multiple) + ] + else: + seq_lens_list = seq_lens_list + [0] * (runtime_shape - + len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( @@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], - ): - (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table, - block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank, - dcp_rank, dcp_size) = param - actual_seq_lengths_kv = forward_context.attn_metadata[ - key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank, - dcp_rank] - pad_length = runtime_shape - len(actual_seq_lengths_kv) - pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) - actual_seq_lengths_kv = np.concatenate( - [actual_seq_lengths_kv, pad_tensor]) - if dcp_size > 1: - num_heads = num_heads * dcp_size + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, value, num_heads, num_kv_heads, scale, + block_table, block_size, actual_seq_lengths_kv, attn_output, + softmax_lse, cp_rank, dcp_rank, dcp_size) = param + actual_seq_lengths_kv = forward_context.attn_metadata[ + key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank, + dcp_rank] + pad_length = runtime_shape - len(actual_seq_lengths_kv) + pad_tensor = np.zeros(pad_length, + dtype=actual_seq_lengths_kv.dtype) + actual_seq_lengths_kv = np.concatenate( + [actual_seq_lengths_kv, pad_tensor]) + if dcp_size > 1: + num_heads = num_heads * dcp_size - with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( @@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], - ): - (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale, - num_kv_heads, attn_output, softmax_lse) = param + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, + scale, num_kv_heads, attn_output, softmax_lse) = param - decode_meta = forward_context.attn_metadata[key].decode - seq_len = decode_meta.cp_seq_len + decode_meta = forward_context.attn_metadata[key].decode + seq_len = decode_meta.cp_seq_len - if speculative_config and speculative_config.method == "deepseek_mtp": - spec_multiple = speculative_config.num_speculative_tokens + 1 - seq_len = seq_len + [0] * (runtime_shape // spec_multiple - - len(seq_len)) - else: - pad_length = runtime_shape - len(seq_len) - pad_tensor = torch.zeros(pad_length, - dtype=seq_len.dtype, - device=seq_len.device) - seq_len = torch.cat([seq_len, pad_tensor], dim=0) + if speculative_config and speculative_config.method == "deepseek_mtp": + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_len = seq_len + [0] * (runtime_shape // spec_multiple - + len(seq_len)) + else: + pad_length = runtime_shape - len(seq_len) + pad_tensor = torch.zeros(pad_length, + dtype=seq_len.dtype, + device=seq_len.device) + seq_len = torch.cat([seq_len, pad_tensor], dim=0) - with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.atb.npu_multi_head_latent_attention(