From e0d58d543b7119980c0326fc37664be7ff2ca79f Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:08:07 +0800 Subject: [PATCH] [main][bugfix] Fix a rare bug triggered by _npu_paged_attention in FULL_DECODE_ONLY mode (#3986) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This PR fixes a bug where the workspace of `_npu_paged_attention` computed during setup is smaller than the one needed at execution. In the current implementation of FULL_DECODE_ONLY with `_npu_paged_attention`, we call `_npu_paged_attention_get_workspace` when capturing, with `max_model_len` as `seq_lens`. This assumes that PA with larger `seq_lens` inputs should have a larger workspace than with smaller `seq_lens`. However, there are rare cases where PA with smaller `seq_lens` incurs a larger workspace. So I add `get_workspace` directly into `update_attn_params`. This change might introduce a small (≈1%) performance degradation for low num_tokens (such as 1) in the decode phase, and there are no other known memory issues. So I think this change is acceptable. We can remove this if a new attention op (such as `npu_fused_infer_attention_score`) does not have such problems. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
- vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac Signed-off-by: Angazenn --- vllm_ascend/compilation/acl_graph.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 82410ec8..af6322ab 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -214,8 +214,16 @@ def update_attn_params(update_stream, forward_context, runtime_shape): output, ) = param seq_lens = forward_context.attn_metadata[key].seq_lens - torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention( + + # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY + # mode with GQA. This is triggered by getting workspace for _npu_paged_attention + # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens + # might encounter a bigger workspace, while currently we use max_model_len to + # calculate max workspace in capturing. So additional get_workspace is added + # here to avoid such bugs. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. 
+ workspace = torch_npu._npu_paged_attention_get_workspace( query=query, key_cache=key_cache, value_cache=value_cache, @@ -224,8 +232,18 @@ def update_attn_params(update_stream, forward_context, runtime_shape): scale_value=scale, block_table=block_table, context_lens=seq_lens, - out=output, - workspace=graph_params.workspaces.get(runtime_shape)) + out=output) + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) torch.npu.graph_task_update_end(update_stream) event.record(update_stream)