diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d99f15cd..1405ed9f 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -440,8 +440,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 block_table=attn_metadata.block_tables,
                 context_lens=attn_metadata.seq_lens,
                 out=output)
-            update_graph_params_workspaces(num_tokens,
-                                           weak_ref_tensors(workspace))
+            update_graph_params_workspaces(num_tokens, workspace)

             # Handle graph capturing mode
             stream = torch_npu.npu.current_stream()
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 3f28a3a3..72cf925d 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -162,6 +162,13 @@ class ACLGraphWrapper:
             # any other acl graph.
             output = weak_ref_tensors(output)

+            # here we always use weak ref for the workspaces
+            # to save memory
+            global _graph_params
+            global _draft_graph_params
+            weak_ref_workspaces(_graph_params)
+            weak_ref_workspaces(_draft_graph_params)
+
             # here we always use weak ref for the output
             # to save memory
             entry.output = weak_ref_tensors(output)
@@ -195,6 +202,16 @@ class ACLGraphWrapper:
         return entry.output


+def weak_ref_workspaces(params):
+    if params is None:
+        return
+    for num_tokens in params.workspaces:
+        if params.workspaces[num_tokens] is None:
+            continue
+        params.workspaces[num_tokens] = weak_ref_tensors(
+            params.workspaces[num_tokens])
+
+
 def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
@@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
 def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+        _graph_params.workspaces[num_tokens] = workspace


 def get_graph_params():