From f9494d978ab82b9b790ae59e49ebcf4330281b09 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:08:57 +0800 Subject: [PATCH] [cherry-pick][v0.11.0-dev][bugfix] Fix a rare bug triggered by _npu_paged_attention in FULL_DECODE_ONLY mode (#3987) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This is cherry-pick from #3986 . This PR fixes a bug where the workspace of `_npu_paged_attention` in setup is smaller than execution. For current implementation of FULL_DECODE_ONLY with `_npu_paged_attention`, we use `_npu_paged_attention_get_workspace` when capturing with `max_model_len` as `seq_lens`. This assumes that PA with larger `seq_lens` inputs should have larger workspace than smaller `seq_lens`. However, there are rare cases where PA with smaller `seq_lens` incurs larger space. So I add `get_workspace` directly into `update_attn_params`. This change might introduce slight(≈1%) performance degradation for small num_tokens(such as 1) in decode phase, and there is no other known memory issues. So I think this change is acceptable. We can remove this if new attention op (such as `npu_fused_infer_attention_score`) does not have such problems. Signed-off-by: Angazenn --- vllm_ascend/compilation/acl_graph.py | 40 ++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index f2cab32..d3e779e 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -212,19 +212,37 @@ def update_attn_params(update_stream, forward_context, runtime_shape): ) = param seq_lens = forward_context.attn_metadata[key].seq_lens + # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY + # mode with GQA. This is triggered by getting workspace for _npu_paged_attention + # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens + # might encounter a bigger workspace, while currently we use max_model_len to + # calculate max workspace in capturing. So additional get_workspace is added + # here to avoid such bugs. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=graph_params.workspaces.get(runtime_shape)) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) torch.npu.graph_task_update_end(update_stream) event.record(update_stream)