[main][bugfix] Fix a rare bug triggered by _npu_paged_attention in FULL_DECODE_ONLY mode (#3986)

### What this PR does / why we need it?
This PR fixes a bug where the workspace of `_npu_paged_attention` in
setup is smaller than execution. For current implementation of
FULL_DECODE_ONLY with `_npu_paged_attention`, we use
`_npu_paged_attention_get_workspace` when capturing with `max_model_len`
as `seq_lens`. This assumes that PA with larger `seq_lens` inputs should
have larger workspace than smaller `seq_lens`. However, there are rare
cases where PA with smaller `seq_lens` incurs larger space. So I add
`get_workspace` directly into `update_attn_params`.
This change might introduce small(≈1%) performance degradation for low
num_tokens(such as 1) in decode phase, and there is no other known
memory issues. So I think this change is acceptable. We can remove this
if new attention op (such as `npu_fused_infer_attention_score`) does not
have such problems.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2025-11-06 23:08:07 +08:00
committed by GitHub
parent 1804b60ec8
commit e0d58d543b

View File

@@ -214,8 +214,16 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
# might encounter a bigger workspace, while currently we use max_model_len to
# calculate max workspace in capturing. So additional get_workspace is added
# here to avoid such bugs.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
@@ -224,8 +232,18 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape))
out=output)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)