[Bugfix] fix the precision issues that may raise from the inter-layer reuse of the workspace in certain scenarios (#5522)
### What this PR does / why we need it?
In the current process of implementing attention updates, the FIA
operator shares a single workspace among different layers within the
same computation graph. To enable memory reuse, we adopt the
weak_ref_tensor mechanism. However, this approach may lead to precision
anomalies in certain scenarios. To address this issue, different layers
in the same computation graph are assigned independent workspaces.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
Signed-off-by: WithHades <244036962@qq.com>
This commit is contained in:
@@ -440,8 +440,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
block_table=attn_metadata.block_tables,
|
block_table=attn_metadata.block_tables,
|
||||||
context_lens=attn_metadata.seq_lens,
|
context_lens=attn_metadata.seq_lens,
|
||||||
out=output)
|
out=output)
|
||||||
update_graph_params_workspaces(num_tokens,
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
weak_ref_tensors(workspace))
|
|
||||||
|
|
||||||
# Handle graph capturing mode
|
# Handle graph capturing mode
|
||||||
stream = torch_npu.npu.current_stream()
|
stream = torch_npu.npu.current_stream()
|
||||||
|
|||||||
@@ -162,6 +162,13 @@ class ACLGraphWrapper:
|
|||||||
# any other acl graph.
|
# any other acl graph.
|
||||||
output = weak_ref_tensors(output)
|
output = weak_ref_tensors(output)
|
||||||
|
|
||||||
|
# here we always use weak ref for the workspaces
|
||||||
|
# to save memory
|
||||||
|
global _graph_params
|
||||||
|
global _draft_graph_params
|
||||||
|
weak_ref_workspaces(_graph_params)
|
||||||
|
weak_ref_workspaces(_draft_graph_params)
|
||||||
|
|
||||||
# here we always use weak ref for the output
|
# here we always use weak ref for the output
|
||||||
# to save memory
|
# to save memory
|
||||||
entry.output = weak_ref_tensors(output)
|
entry.output = weak_ref_tensors(output)
|
||||||
@@ -195,6 +202,16 @@ class ACLGraphWrapper:
|
|||||||
return entry.output
|
return entry.output
|
||||||
|
|
||||||
|
|
||||||
|
def weak_ref_workspaces(params):
|
||||||
|
if params is None:
|
||||||
|
return
|
||||||
|
for num_tokens in params.workspaces:
|
||||||
|
if params.workspaces[num_tokens] is None:
|
||||||
|
continue
|
||||||
|
params.workspaces[num_tokens] = weak_ref_tensors(
|
||||||
|
params.workspaces[num_tokens])
|
||||||
|
|
||||||
|
|
||||||
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||||
@@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
|
|||||||
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
|
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
|
||||||
global _graph_params
|
global _graph_params
|
||||||
if _graph_params is not None:
|
if _graph_params is not None:
|
||||||
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
|
_graph_params.workspaces[num_tokens] = workspace
|
||||||
|
|
||||||
|
|
||||||
def get_graph_params():
|
def get_graph_params():
|
||||||
|
|||||||
Reference in New Issue
Block a user