[Bugfix] fix the precision issues that may raise from the inter-layer reuse of the workspace in certain scenarios (#5522)

### What this PR does / why we need it?

In the current process of implementing attention updates, the FIA
operator shares a single workspace among different layers within the
same computation graph. To enable memory reuse, we adopt the
weak_ref_tensor mechanism. However, this approach may lead to precision
anomalies in certain scenarios. To address this issue, different layers
in the same computation graph are assigned independent workspaces.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

Signed-off-by: WithHades <244036962@qq.com>
This commit is contained in:
无脸男
2025-12-31 16:54:04 +08:00
committed by GitHub
parent 46a1614387
commit 03679cf1d3
2 changed files with 19 additions and 3 deletions

View File

@@ -440,8 +440,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

View File

@@ -162,6 +162,13 @@ class ACLGraphWrapper:
# any other acl graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the workspaces
# to save memory
global _graph_params
global _draft_graph_params
weak_ref_workspaces(_graph_params)
weak_ref_workspaces(_draft_graph_params)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
@@ -195,6 +202,16 @@ class ACLGraphWrapper:
return entry.output
def weak_ref_workspaces(params):
if params is None:
return
for num_tokens in params.workspaces:
if params.workspaces[num_tokens] is None:
continue
params.workspaces[num_tokens] = weak_ref_tensors(
params.workspaces[num_tokens])
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
@@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
global _graph_params
if _graph_params is not None:
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
_graph_params.workspaces[num_tokens] = workspace
def get_graph_params():