diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 98c8c57..bf881b1 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -39,6 +39,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
+from ..utils import weak_ref_tensors
+
 
 
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -402,15 +404,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.events[num_tokens].append(event)
 
             graph_params.attn_params[num_tokens].append((
-                query,
-                self.key_cache,
-                self.value_cache,
+                weak_ref_tensors(query),
+                weak_ref_tensors(self.key_cache),
+                weak_ref_tensors(self.value_cache),
                 self.num_kv_heads,
                 self.num_heads,
                 self.scale,
-                attn_metadata.block_tables,
+                weak_ref_tensors(attn_metadata.block_tables),
                 attn_metadata.seq_lens,
-                output,
+                weak_ref_tensors(output),
             ))
 
             torch.npu.graph_task_group_begin(stream)