From 03679cf1d38949eabb1cfeb53c02996e9b124117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=97=A0=E8=84=B8=E7=94=B7?= <244036962@qq.com> Date: Wed, 31 Dec 2025 16:54:04 +0800 Subject: [PATCH] [Bugfix] fix the precision issues that may arise from the inter-layer reuse of the workspace in certain scenarios (#5522) ### What this PR does / why we need it? In the current process of implementing attention updates, the FIA operator shares a single workspace among different layers within the same computation graph. To enable memory reuse, we adopt the weak_ref_tensor mechanism. However, this approach may lead to precision anomalies in certain scenarios. To address this issue, different layers in the same computation graph are assigned independent workspaces. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4 Signed-off-by: WithHades <244036962@qq.com> --- vllm_ascend/attention/attention_v1.py | 3 +-- vllm_ascend/compilation/acl_graph.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index d99f15cd..1405ed9f 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -440,8 +440,7 @@ class AscendAttentionBackendImpl(AttentionImpl): block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, out=output) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) + update_graph_params_workspaces(num_tokens, workspace) # Handle graph capturing mode stream = torch_npu.npu.current_stream() diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 3f28a3a3..72cf925d 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -162,6 +162,13 @@ class ACLGraphWrapper: # any
other acl graph. output = weak_ref_tensors(output) + # here we always use weak ref for the workspaces + # to save memory + global _graph_params + global _draft_graph_params + weak_ref_workspaces(_graph_params) + weak_ref_workspaces(_draft_graph_params) + # here we always use weak ref for the output # to save memory entry.output = weak_ref_tensors(output) @@ -195,6 +202,16 @@ class ACLGraphWrapper: return entry.output +def weak_ref_workspaces(params): + if params is None: + return + for num_tokens in params.workspaces: + if params.workspaces[num_tokens] is None: + continue + params.workspaces[num_tokens] = weak_ref_tensors( + params.workspaces[num_tokens]) + + def _update_attn_pa_params(update_stream, forward_context, runtime_shape): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args @@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]): def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor): global _graph_params if _graph_params is not None: - _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace) + _graph_params.workspaces[num_tokens] = workspace def get_graph_params():