From 1d3544c8875ed66c7dd7c1235c0825bfe21ad730 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:49:22 +0800 Subject: [PATCH] [BugFix]converting pa get_workspace back to capturing (#5833) ### What this PR does / why we need it? This helps to fix a bug in for pa get_workspace. In earlier implementation, we use `_npu_paged_attention_get_workspace` in `_update_pa_attn_params`. However, this might cause some potential memory problems as it dynamically allocate new memory for workspace when calling this api. Therefor, we move this back to capturing, and use a fixed `SEQ_LEN_WITH_MAX_PA_WORKSPACE` to get max workspace. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d Signed-off-by: Angazenn --- vllm_ascend/compilation/acl_graph.py | 21 +-------------------- vllm_ascend/worker/model_runner_v1.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index ea979fbf..c1f1a798 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -229,25 +229,6 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape): ) = param seq_lens = forward_context.attn_metadata[key].seq_lens - # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY - # mode with GQA. This is triggered by getting workspace for _npu_paged_attention - # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens - # might encounter a bigger workspace, while currently we use max_model_len to - # calculate max workspace in capturing. So additional get_workspace is added - # here to avoid such bugs. - # TODO(Angazenn): we will remove this once _npu_paged_attention is fully - # replaced by npu_fused_infer_attention_score which does not contain such bugs. - workspace = torch_npu._npu_paged_attention_get_workspace( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - ) torch.npu.graph_task_update_begin(update_stream, handle) torch_npu._npu_paged_attention( query=query, @@ -259,7 +240,7 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape): block_table=block_table, context_lens=seq_lens, out=output, - workspace=workspace, + workspace=graph_params.workspaces.get(runtime_shape), ) torch.npu.graph_task_update_end(update_stream) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c8dab784..e2fb61f7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -76,7 +76,7 @@ from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention # yapf conflicts with isort for this block # yapf: disable from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, @@ -133,6 +133,9 @@ if get_ascend_device_type() == AscendDeviceType._310P: torch_npu.npu.set_compile_mode(jit_compile=False) +SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144 + + @dataclass class GraphCaptureContext: stream: torch.npu.Stream @@ -1919,6 +1922,7 @@ class NPUModelRunner(GPUModelRunner): num_scheduled_tokens: np.ndarray, aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, + is_graph_capturing: bool = False, ) -> Optional[dict[str, Any]]: attn_metadata: Optional[dict[str, Any]] = None @@ -1928,7 +1932,12 @@ class NPUModelRunner(GPUModelRunner): attn_metadata = {} - seq_lens = max_query_len + # The reason why we use a fixed seq_len rather than max_query_len is that + # _npu_paged_attention_get_workspace only returns max workspace with specific + # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len + # in inference. This will be removed once npu_fused_infer_attention_score + # outperforms _npu_paged_attention on all cases. + seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() @@ -2177,6 +2186,7 @@ class NPUModelRunner(GPUModelRunner): max_query_len=max_query_len, aclgraph_runtime_mode=cudagraph_runtime_mode, force_attention=force_attention, + is_graph_capturing=is_graph_capturing, num_scheduled_tokens=num_scheduled_tokens, )