[Bugfix] fix the precision issues that may raise from the inter-layer reuse of the workspace in certain scenarios (#5522)
### What this PR does / why we need it?
In the current process of implementing attention updates, the FIA
operator shares a single workspace among different layers within the
same computation graph. To enable memory reuse, we adopt the
weak_ref_tensor mechanism. However, this approach may lead to precision
anomalies in certain scenarios. To address this issue, different layers
in the same computation graph are assigned independent workspaces.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
Signed-off-by: WithHades <244036962@qq.com>
This commit is contained in:
@@ -440,8 +440,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
@@ -162,6 +162,13 @@ class ACLGraphWrapper:
|
||||
# any other acl graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the workspaces
|
||||
# to save memory
|
||||
global _graph_params
|
||||
global _draft_graph_params
|
||||
weak_ref_workspaces(_graph_params)
|
||||
weak_ref_workspaces(_draft_graph_params)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
@@ -195,6 +202,16 @@ class ACLGraphWrapper:
|
||||
return entry.output
|
||||
|
||||
|
||||
def weak_ref_workspaces(params):
|
||||
if params is None:
|
||||
return
|
||||
for num_tokens in params.workspaces:
|
||||
if params.workspaces[num_tokens] is None:
|
||||
continue
|
||||
params.workspaces[num_tokens] = weak_ref_tensors(
|
||||
params.workspaces[num_tokens])
|
||||
|
||||
|
||||
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
@@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
|
||||
Reference in New Issue
Block a user